diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/drop_path.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/drop_path.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34bd6b4270fa854ed3eeb8aba51cdd710575d2ce Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/drop_path.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/filtering.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/filtering.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..208cc78dcb1f724604e1b770e57f814e7649faa3 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/filtering.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/gmm.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/gmm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1703fe41689005f963661edacd0d55d85ebb4e26 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/gmm.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/spatial_transforms.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/spatial_transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29e926c3af77a3182776493e5e12a6b4469c3f29 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/spatial_transforms.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dd45c884efcf1b347ce94fda2fd3a1805a0247d Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/__init__.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/ahnet.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/ahnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ab6ba14ba20c634736aff7336e42b7698aa8f4f Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/ahnet.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/classifier.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/classifier.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a664057a0efcb413c10689f2d2562b591ed23dcc Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/classifier.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/densenet.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/densenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27dda7ee5b8935de6172539d5fa5251a63f20825 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/densenet.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/dints.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/dints.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..901d6674e81a2bdc60ed779452eabe2be4b150da Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/dints.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/generator.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/generator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..336a878789c83871f7cf7ad77710437794befab4 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/generator.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/highresnet.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/highresnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa33c67c900a132691e9ba8e998e63da7d853cce Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/highresnet.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/milmodel.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/milmodel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfffed4631b78a3b11805d56257afae625f31b7f Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/milmodel.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/transchex.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/transchex.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..417f4445b517260a61e71366f3333b0322da5eb9 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/transchex.cpython-38.pyc differ diff --git 
a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/vitautoenc.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/vitautoenc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af20468fc7bb61fc3c0f7fb5f52f171e22d90297 Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/vitautoenc.cpython-38.pyc differ diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_blocks.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a6f30b794819fbd65d57d528d718ab077ad13e --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_blocks.py @@ -0,0 +1,280 @@ +""" EfficientNet, MobileNetV3, etc Blocks + +Hacked together by / Copyright 2019, Ross Wightman +""" + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer + +__all__ = [ + 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class SqueezeExcite(nn.Module): + """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family + + Args: + in_chs (int): input channels to layer + rd_ratio (float): ratio of squeeze reduction + act_layer (nn.Module): activation layer of containing block + gate_layer (Callable): attention gate function + force_act_layer (nn.Module): override block's activation fn if this is set/bound + rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs + """ + + def __init__( + self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU, + gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None): + super(SqueezeExcite, self).__init__() + if rd_channels is None: + rd_round_fn = rd_round_fn or round + rd_channels = rd_round_fn(in_chs * rd_ratio) + act_layer = force_act_layer or act_layer + self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True) + self.act1 = create_act_layer(act_layer, inplace=True) + self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + return x * self.gate(x_se) + + +class ConvBnAct(nn.Module): + """ Conv + Norm Layer + Activation w/ optional skip connection + """ + def __init__( + self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='', + skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.): + super(ConvBnAct, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + groups = num_groups(group_size, in_chs) + self.has_skip = skip and stride == 1 and in_chs == out_chs + + self.conv = create_conv2d( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + self.bn1 = 
norm_act_layer(out_chs, inplace=True) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # output of conv after act, same as block output + return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv.out_channels) + + def forward(self, x): + shortcut = x + x = self.conv(x) + x = self.bn1(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion + (factor of 1.0). This is an alternative to having an IR with an optional first pw conv. + """ + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', + noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + se_layer=None, drop_path_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + groups = num_groups(group_size, in_chs) + self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + + self.conv_dw = create_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups) + self.bn1 = norm_act_layer(in_chs, inplace=True) + + # Squeeze-and-excitation + self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() + + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PW + return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv_pw.out_channels) + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + x = self.bn1(x) + x = self.se(x) + x = self.conv_pw(x) + x = self.bn2(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE + + Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often + referred to as 'MBConv' (Mobile inverted bottleneck conv) and is also used in + * MNasNet - https://arxiv.org/abs/1807.11626 + * EfficientNet - https://arxiv.org/abs/1905.11946 + * MobileNet-V3 - https://arxiv.org/abs/1905.02244 + """ + + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): + super(InvertedResidual, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + groups = num_groups(group_size, mid_chs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + + # Point-wise expansion + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_act_layer(mid_chs,
inplace=True) + + # Depth-wise convolution + self.conv_dw = create_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + groups=groups, padding=pad_type, **conv_kwargs) + self.bn2 = norm_act_layer(mid_chs, inplace=True) + + # Squeeze-and-excitation + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + + # Point-wise linear projection + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_act_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv_pwl.out_channels) + + def forward(self, x): + shortcut = x + x = self.conv_pw(x) + x = self.bn1(x) + x = self.conv_dw(x) + x = self.bn2(x) + x = self.se(x) + x = self.conv_pwl(x) + x = self.bn3(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size, + pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs, + drop_path_rate=drop_path_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + shortcut = x + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.se(x) + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x + + +class EdgeResidual(nn.Module): + """ Residual block with expansion convolution followed by pointwise-linear w/ stride + + Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML` + - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html + + This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers + * MobileDet - https://arxiv.org/abs/2004.14525 + * EfficientNet-X - https://arxiv.org/abs/2102.05610 + * EfficientNet-V2 - https://arxiv.org/abs/2104.00298 + """ + + def __init__( + self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='', + force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): + super(EdgeResidual, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + if force_in_chs > 0: + mid_chs = make_divisible(force_in_chs * exp_ratio) + else: + mid_chs = 
make_divisible(in_chs * exp_ratio) + groups = num_groups(group_size, in_chs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + + # Expansion convolution + self.conv_exp = create_conv2d( + in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + self.bn1 = norm_act_layer(mid_chs, inplace=True) + + # Squeeze-and-excitation + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + + # Point-wise linear projection + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_act_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, before PWL + return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv_pwl.out_channels) + + def forward(self, x): + shortcut = x + x = self.conv_exp(x) + x = self.bn1(x) + x = self.se(x) + x = self.conv_pwl(x) + x = self.bn2(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_manipulate.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_manipulate.py new file mode 100644 index 0000000000000000000000000000000000000000..e689b39276650b3de9d6e8e2395a6057f6873074 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_manipulate.py @@ -0,0 +1,278 @@ +import collections.abc +import math +import re +from collections import defaultdict +from itertools import chain +from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union + +import torch +from torch import nn as nn +from torch.utils.checkpoint import checkpoint + +__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', + 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] + + +def model_parameters(model: nn.Module, exclude_head: bool = False): + if exclude_head: + # FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply( + fn: Callable, + module: nn.Module, name='', + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules( + module: nn.Module, + name: str = '', + depth_first: bool = True, + include_root: bool = False, +): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module + + +def named_modules_with_params( + module: nn.Module, + name: str = '', + depth_first: bool = True, +
include_root: bool = False, +): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects: Iterator[Tuple[str, Any]], + group_matcher: Union[Dict, Callable], + return_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if return_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not return_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + return_values: bool = False, + reverse: bool = False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + return_values: bool = False, + reverse: bool = False, +): + return group_with_matcher( + named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse) + + +def flatten_modules( + named_modules: Iterator[Tuple[str, nn.Module]], + depth: int = 1, + prefix: Union[str, Tuple[str, ...]] = '', + module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential', +): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and 
isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except the last will run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. warning:: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...)
+ >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_pretrained.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..2938f8fe714d9bcfe5477b6f083defc31ee8e66a --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_pretrained.py @@ -0,0 +1,94 @@ +import copy +from collections import deque, defaultdict +from dataclasses import dataclass, field, replace, asdict +from typing import Any, Deque, Dict, Tuple, Optional, Union + + +__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg'] + + +@dataclass +class PretrainedCfg: + """ + """ + # weight source locations + url: Optional[Union[str, Tuple[str, str]]] = None # remote URL + file: Optional[str] = None # local / shared filesystem path + state_dict: Optional[Dict[str, Any]] = None # in-memory state dict + hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model') + hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default) + + source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) + architecture: Optional[str] = None # architecture variant can be set when not implicit + tag: Optional[str] = None # pretrained tag of source + custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) + + # input / data config + input_size: Tuple[int, int, int] = (3, 224, 224) + test_input_size: Optional[Tuple[int, int, int]] = None + min_input_size: Optional[Tuple[int, int, int]] = None + fixed_input_size: bool = False + interpolation: str = 
'bicubic' + crop_pct: float = 0.875 + test_crop_pct: Optional[float] = None + crop_mode: str = 'center' + mean: Tuple[float, ...] = (0.485, 0.456, 0.406) + std: Tuple[float, ...] = (0.229, 0.224, 0.225) + + # head / classifier config and meta-data + num_classes: int = 1000 + label_offset: Optional[int] = None + label_names: Optional[Tuple[str]] = None + label_descriptions: Optional[Dict[str, str]] = None + + # model attributes that vary with above or required for pretrained adaptation + pool_size: Optional[Tuple[int, ...]] = None + test_pool_size: Optional[Tuple[int, ...]] = None + first_conv: Optional[str] = None + classifier: Optional[str] = None + + license: Optional[str] = None + description: Optional[str] = None + origin_url: Optional[str] = None + paper_name: Optional[str] = None + paper_ids: Optional[Union[str, Tuple[str]]] = None + notes: Optional[Tuple[str]] = None + + @property + def has_weights(self): + return self.url or self.file or self.hf_hub_id + + def to_dict(self, remove_source=False, remove_null=True): + return filter_pretrained_cfg( + asdict(self), + remove_source=remove_source, + remove_null=remove_null + ) + + +def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): + filtered_cfg = {} + keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none + for k, v in cfg.items(): + if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: + continue + if remove_null and v is None and k not in keep_null: + continue + filtered_cfg[k] = v + return filtered_cfg + + +@dataclass +class DefaultCfg: + tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default) + cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag + is_pretrained: bool = False # at least one of the configs has a pretrained source set + + @property + def default(self): + return self.cfgs[self.tags[0]] + + @property + def default_with_tag(self): + tag = self.tags[0] + return tag, self.cfgs[tag] diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convmixer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..854f84a07f7c51ab9fe2cd7df47ae9e9f8d1e0e8 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convmixer.py @@ -0,0 +1,144 @@ +""" ConvMixer + +""" +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SelectAdaptivePool2d +from ._registry import register_model, generate_default_cfgs +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq + +__all__ = ['ConvMixer'] + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +class ConvMixer(nn.Module): + def __init__( + self, + dim, + depth, + kernel_size=9, + patch_size=7, + in_chans=3, + num_classes=1000, + global_pool='avg', + drop_rate=0., + act_layer=nn.GELU, + **kwargs, + ): + super().__init__() + self.num_classes = num_classes + self.num_features = dim + self.grad_checkpointing = False + + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), + act_layer(), + nn.BatchNorm2d(dim) + ) + self.blocks = nn.Sequential( + *[nn.Sequential( + Residual(nn.Sequential( + nn.Conv2d(dim, 
dim, kernel_size, groups=dim, padding="same"), + act_layer(), + nn.BatchNorm2d(dim) + )), + nn.Conv2d(dim, dim, kernel_size=1), + act_layer(), + nn.BatchNorm2d(dim) + ) for i in range(depth)] + ) + self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.pooling(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_convmixer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for ConvMixer models.') + + return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + 'first_conv': 'stem.0', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'convmixer_1536_20.in1k': _cfg(hf_hub_id='timm/'), + 'convmixer_768_32.in1k': _cfg(hf_hub_id='timm/'), + 'convmixer_1024_20_ks9_p14.in1k': _cfg(hf_hub_id='timm/') +}) + + + +@register_model +def convmixer_1536_20(pretrained=False, **kwargs) -> ConvMixer: + model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) + return _create_convmixer('convmixer_1536_20', pretrained, **model_args) + + +@register_model +def convmixer_768_32(pretrained=False, **kwargs) -> ConvMixer: + model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs) + return _create_convmixer('convmixer_768_32', pretrained, **model_args) + + +@register_model +def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs) -> ConvMixer: + model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) + return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args) \ No newline at end of file diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convnext.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7fd20b24be0a194c54c1997589f462cba35d98 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convnext.py @@ -0,0 +1,1098 @@ +""" ConvNeXt + +Papers: +* `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf +@Article{liu2022convnet, + author 
= {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie}, + title = {A ConvNet for the 2020s}, + journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2022}, +} + +* `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 +@article{Woo2023ConvNeXtV2, + title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders}, + author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie}, + year={2023}, + journal={arXiv preprint arXiv:2301.00808}, +} + +Original code and weights from: +* https://github.com/facebookresearch/ConvNeXt, original copyright below +* https://github.com/facebookresearch/ConvNeXt-V2, original copyright below + +Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals. + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# ConvNeXt +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# This source code is licensed under the MIT license + +# ConvNeXt-V2 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)) +# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially. + +from collections import OrderedDict +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \ + LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple +from timm.layers import NormMlpClassifierHead, ClassifierHead +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this + + +class Downsample(nn.Module): + + def __init__(self, in_chs, out_chs, stride=1, dilation=1): + super().__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + + if in_chs != out_chs: + self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + self.conv = nn.Identity() + + def forward(self, x): + x = self.pool(x) + x = self.conv(x) + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear + is a better choice. 
This was observed with PyTorch 1.10 on a 3090 GPU; it could change over time & w/ different HW. + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = (1, 1), + mlp_ratio: float = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + ls_init_value: Optional[float] = 1e-6, + act_layer: Union[str, Callable] = 'gelu', + norm_layer: Optional[Callable] = None, + drop_path: float = 0., + ): + """ + + Args: + in_chs: Block input channels. + out_chs: Block output channels (same as in_chs if None). + kernel_size: Depthwise convolution kernel size. + stride: Stride of depthwise convolution. + dilation: Tuple specifying input and output dilation of block. + mlp_ratio: MLP expansion ratio. + conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True. + conv_bias: Apply bias for all convolution (linear) layers. + use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2) + ls_init_value: Layer-scale init values, layer-scale applied if not None. + act_layer: Activation layer. + norm_layer: Normalization layer (defaults to LN if not specified). + drop_path: Stochastic depth probability. + """ + super().__init__() + out_chs = out_chs or in_chs + dilation = to_ntuple(2)(dilation) + act_layer = get_act_layer(act_layer) + if not norm_layer: + norm_layer = LayerNorm2d if conv_mlp else LayerNorm + mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp) + self.use_conv_mlp = conv_mlp + self.conv_dw = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation[0], + depthwise=True, + bias=conv_bias, + ) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0]) + else: + self.shortcut = nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = x.permute(0, 3, 1, 2) + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + + x = self.drop_path(x) + self.shortcut(shortcut) + return x + + +class ConvNeXtStage(nn.Module): + + def __init__( + self, + in_chs, + out_chs, + kernel_size=7, + stride=2, + depth=2, + dilation=(1, 1), + drop_path_rates=None, + ls_init_value=1.0, + conv_mlp=False, + conv_bias=True, + use_grn=False, + act_layer='gelu', + norm_layer=None, + norm_layer_cl=None + ): + super().__init__() + self.grad_checkpointing = False + + if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]: + ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 + pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used + self.downsample = nn.Sequential( + norm_layer(in_chs), + create_conv2d( + in_chs, + out_chs, + kernel_size=ds_ks, + stride=stride, + dilation=dilation[0], + padding=pad, + bias=conv_bias, + ), + ) + in_chs = out_chs + else: + self.downsample = nn.Identity() + + drop_path_rates = drop_path_rates or [0.]
* depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(ConvNeXtBlock( + in_chs=in_chs, + out_chs=out_chs, + kernel_size=kernel_size, + dilation=dilation[1], + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer if conv_mlp else norm_layer_cl, + )) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf + """ + + def __init__( + self, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + output_stride: int = 32, + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] = (96, 192, 384, 768), + kernel_sizes: Union[int, Tuple[int, ...]] = 7, + ls_init_value: Optional[float] = 1e-6, + stem_type: str = 'patch', + patch_size: int = 4, + head_init_scale: float = 1., + head_norm_first: bool = False, + head_hidden_size: Optional[int] = None, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + act_layer: Union[str, Callable] = 'gelu', + norm_layer: Optional[Union[str, Callable]] = None, + norm_eps: Optional[float] = None, + drop_rate: float = 0., + drop_path_rate: float = 0., + ): + """ + Args: + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Global pooling type. + output_stride: Output stride of network, one of (8, 16, 32). + depths: Number of blocks at each stage. + dims: Feature dimension at each stage. + kernel_sizes: Depthwise convolution kernel-sizes for each stage. + ls_init_value: Init value for Layer Scale, disabled if None. + stem_type: Type of stem. + patch_size: Stem patch size for patch stem. + head_init_scale: Init scaling value for classifier weights and biases. + head_norm_first: Apply normalization before global pool + head. + head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False. + conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last. + conv_bias: Use bias layers w/ all convolutions. + use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + drop_rate: Head pre-classifier dropout rate. + drop_path_rate: Stochastic depth drop rate. 
+ """ + super().__init__() + assert output_stride in (8, 16, 32) + kernel_sizes = to_ntuple(4)(kernel_sizes) + if norm_layer is None: + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + else: + assert conv_mlp,\ + 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' + norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + + self.num_classes = num_classes + self.drop_rate = drop_rate + self.feature_info = [] + + assert stem_type in ('patch', 'overlap', 'overlap_tiered') + if stem_type == 'patch': + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias), + norm_layer(dims[0]), + ) + stem_stride = patch_size + else: + mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0] + self.stem = nn.Sequential( + nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias), + nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias), + norm_layer(dims[0]), + ) + stem_stride = 4 + + self.stages = nn.Sequential() + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + stages = [] + prev_chs = dims[0] + curr_stride = stem_stride + dilation = 1 + # 4 feature resolution stages, each consisting of multiple residual blocks + for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + out_chs = dims[i] + stages.append(ConvNeXtStage( + prev_chs, + out_chs, + kernel_size=kernel_sizes[i], + stride=stride, + dilation=(first_dilation, dilation), + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + )) + prev_chs = out_chs + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + self.num_features = prev_chs + + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets + # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) + if head_norm_first: + assert not head_hidden_size + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + hidden_size=head_hidden_size, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + act_layer='gelu', + ) + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + 
@torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + self.head.reset(num_classes, global_pool) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + nn.init.zeros_(module.bias) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint + if 'model' in state_dict: + state_dict = state_dict['model'] + + out_dict = {} + if 'visual.trunk.stem.0.weight' in state_dict: + out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')} + if 'visual.head.proj.weight' in state_dict: + out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight'] + out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) + elif 'visual.head.mlp.fc1.weight' in state_dict: + out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight'] + out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias'] + out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight'] + out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0]) + return out_dict + + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + if 'grn' in k: + k = k.replace('grn.beta', 'mlp.grn.bias') + k = k.replace('grn.gamma', 'mlp.grn.weight') + v = v.reshape(v.shape[-1]) + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + + return out_dict + + +def _create_convnext(variant, pretrained=False, **kwargs): + if kwargs.get('pretrained_cfg', '') == 'fcmae': + # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`) + # This is a workaround for loading with num_classes=0 w/o removing the norm-layer.
+ kwargs.setdefault('pretrained_strict', False) + + model = build_model_with_cfg( + ConvNeXt, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +def _cfgv2(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808', + 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders', + 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # timm specific variants + 'convnext_tiny.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_atto.d2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_atto_ols.a2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_femto.d1_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_femto_ols.d1_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_pico.d1_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_pico_ols.d1_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano.d1h_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano_ols.d1h_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_tiny_hnf.a2h_in1k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_small.in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'convnext_nano.in12k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, num_classes=11821), + 'convnext_tiny.in12k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, num_classes=11821), + 'convnext_small.in12k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, num_classes=11821), + + 'convnext_tiny.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_base.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_large.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_xlarge.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_base.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_large.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_small.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_base.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_large.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg( + 
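# Each key in default_cfgs is '<architecture>.<weight tag>'; with timm's create_model the
# tag picks which pretrained weights back the same architecture. Sketch (assumes network
# access and a recent timm):
#
#   >>> import timm, torch
#   >>> m = timm.create_model('convnext_tiny.fb_in22k_ft_in1k', pretrained=True)
#   >>> m(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 1000])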
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'convnext_tiny.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_small.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_base.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_large.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_xlarge.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + + 'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt', + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt", + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt", + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt", + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt", + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt", + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'), + + 'convnextv2_atto.fcmae_ft_in1k': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt', + hf_hub_id='timm/', + test_input_size=(3, 288, 
288), test_crop_pct=0.95), + 'convnextv2_femto.fcmae_ft_in1k': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnextv2_pico.fcmae_ft_in1k': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnextv2_nano.fcmae_ft_in1k': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt', + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_tiny.fcmae_ft_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_base.fcmae_ft_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_large.fcmae_ft_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnextv2_huge.fcmae_ft_in1k': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt", + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnextv2_atto.fcmae': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt', + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_femto.fcmae': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt', + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_pico.fcmae': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt', + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_nano.fcmae': _cfgv2( + url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt', + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_tiny.fcmae': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt", + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_base.fcmae': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt", + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_large.fcmae': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt", + hf_hub_id='timm/', + num_classes=0), + 'convnextv2_huge.fcmae': _cfgv2( + url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt", + hf_hub_id='timm/', + num_classes=0), + + 'convnextv2_small.untrained': _cfg(), + + # CLIP weights, fine-tuned on in1k or in12k + in1k + 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0), + 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + 
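# The '.fcmae' entries above are FCMAE pretrain-only encoders (num_classes=0, no classifier
# or final head norm), so they should come up as feature extractors; a sketch of the expected
# behaviour (320 is convnextv2_atto's last-stage width per the model defs further below):
#
#   >>> m = timm.create_model('convnextv2_atto.fcmae', pretrained=True)
#   >>> m(torch.randn(1, 3, 224, 224)).shape  # pooled features, no logits
#   torch.Size([1, 320])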
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0), + 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0), + 'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0 + ), + 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash' + ), + 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0), + + 'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0), + 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0), + 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0), + + # CLIP original image tower weights + 'convnext_base.clip_laion2b': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laion2b_augreg': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona_320': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona_augreg_320': _cfg( + 
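# For the original CLIP image-tower weights in this block, num_classes is the width of the
# CLIP projection head (640/768/1024 here), not a label count; the "logits" are the image
# embedding used for contrastive training. Sketch:
#
#   >>> m = timm.create_model('convnext_base.clip_laion2b', pretrained=True)
#   >>> m(torch.randn(1, 3, 256, 256)).shape
#   torch.Size([1, 640])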
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640), + 'convnext_large_mlp.clip_laion2b_augreg': _cfg( + hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768), + 'convnext_large_mlp.clip_laion2b_ft_320': _cfg( + hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768), + 'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg( + hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768), + 'convnext_xxlarge.clip_laion2b_soup': _cfg( + hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024), + 'convnext_xxlarge.clip_laion2b_rewind': _cfg( + hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024), +}) + + +@register_model +def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt: + # timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M) + model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True) + model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt: + # timm atto variant with overlapping 3x3 conv stem, wider than non-ols atto above, current param count 3.7M + model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered') + model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt: + # timm femto variant + model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True) + model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt: + # timm femto variant with overlapping 3x3 conv stem + model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered') + model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt: + # timm pico variant + model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True) + model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt: + #
timm pico variant with overlapping 3x3 conv stem + model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered') + model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt: + # timm nano variant with standard stem and head + model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True) + model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt: + # experimental nano variant with overlapping conv stem + model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap') + model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt: + # experimental tiny variant with norm before pooling in head (head norm first) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True) + model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768)) + model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_small(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]) + model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_base(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]) + model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_large(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]) + model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536) + model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048]) + model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5)) + model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt: + # timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M) + model_args = dict( + depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True,
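# ConvNeXt-V2 ('use_grn=True' with 'ls_init_value=None') replaces LayerScale with Global
# Response Normalization inside the MLP. Sketch of GRN as described in the ConvNeXt-V2
# paper (not the actual timm layer; gamma/beta are zero-initialized learnables):
#
#   >>> import torch
#   >>> x = torch.randn(2, 7, 7, 64)                       # NHWC features
#   >>> gx = x.norm(p=2, dim=(1, 2), keepdim=True)         # per-channel spatial L2 norm
#   >>> nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)   # divisive channel normalization
#   >>> gamma = beta = torch.zeros(64)
#   >>> y = gamma * (x * nx) + beta + x                    # residual keeps init as identity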
ls_init_value=None, conv_mlp=True) + model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt: + # timm femto variant + model_args = dict( + depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True) + model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt: + # timm pico variant + model_args = dict( + depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True) + model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt: + # timm nano variant with standard stem and head + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True) + model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None) + model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None) + model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None) + model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None) + model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt: + model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None) + model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +register_model_deprecations(__name__, { + 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k', + 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k', + 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k', + 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k', + 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k', + 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384', + 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384', + 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384', + 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384', + 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384', + 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k', + 'convnext_small_in22k': 'convnext_small.fb_in22k', + 'convnext_base_in22k': 
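# register_model_deprecations keeps the old pre-tag names usable: requesting a legacy name
# should resolve to the new '<arch>.<tag>' entry (recent timm also emits a deprecation
# warning). Sketch:
#
#   >>> m = timm.create_model('convnext_tiny_in22ft1k', pretrained=True)
#   >>> # behaves like: timm.create_model('convnext_tiny.fb_in22k_ft_in1k', pretrained=True)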
'convnext_base.fb_in22k', + 'convnext_large_in22k': 'convnext_large.fb_in22k', + 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k', +}) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/davit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/davit.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d6ad690a03035f85800b1a15d66e4af6bff7ad --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/davit.py @@ -0,0 +1,689 @@ +""" DaViT: Dual Attention Vision Transformers + +As described in https://arxiv.org/abs/2204.03645 + +Input size invariant transformer architecture that combines channel and spatial +attention in each block. The attention mechanisms used are linear in complexity. + +DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below + +""" +# Copyright (c) 2022 Mingyu Ding +# All rights reserved. +# This source code is licensed under the MIT license +from functools import partial +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn +from timm.layers import NormMlpClassifierHead, ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model + +__all__ = ['DaVit'] + + +class ConvPosEnc(nn.Module): + def __init__(self, dim: int, k: int = 3, act: bool = False): + super(ConvPosEnc, self).__init__() + + self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim) + self.act = nn.GELU() if act else nn.Identity() + + def forward(self, x: Tensor): + feat = self.proj(x) + x = x + self.act(feat) + return x + + +class Stem(nn.Module): + """ Size-agnostic implementation of 2D image to patch embedding, + allowing input size to be adjusted during model forward operation + """ + + def __init__( + self, + in_chs=3, + out_chs=96, + stride=4, + norm_layer=LayerNorm2d, + ): + super().__init__() + stride = to_2tuple(stride) + self.stride = stride + self.in_chs = in_chs + self.out_chs = out_chs + assert stride[0] == 4 # only setup for stride==4 + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=7, + stride=stride, + padding=3, + ) + self.norm = norm_layer(out_chs) + + def forward(self, x: Tensor): + B, C, H, W = x.shape + x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1])) + x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0])) + x = self.conv(x) + x = self.norm(x) + return x + + +class Downsample(nn.Module): + def __init__( + self, + in_chs, + out_chs, + norm_layer=LayerNorm2d, + ): + super().__init__() + self.in_chs = in_chs + self.out_chs = out_chs + + self.norm = norm_layer(in_chs) + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=2, + stride=2, + padding=0, + ) + + def forward(self, x: Tensor): + B, C, H, W = x.shape + x = self.norm(x) + x = F.pad(x, (0, (2 - W % 2) % 2)) + x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2)) + x = self.conv(x) + return x + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + +
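# Two notes on the modules above, with worked numbers (hypothetical input sizes):
# - Stem pads H/W up to a multiple of stride 4 before the 7x7/stride-4 conv, which is what
#   makes DaViT input-size agnostic: a 225x225 image pads to 228x228 and maps to 57x57,
#   since floor((228 + 2*3 - 7)/4) + 1 = 57, i.e. exactly 228/4.
# - ChannelAttention (continued below) attends over channels, not tokens: k^T @ v is only
#   (head_dim x head_dim), so cost grows linearly with token count N rather than as N^2.
#   Shapes: q, k, v (B, heads, N, d) -> k^T v (B, heads, d, d) -> applied to q^T and
#   transposed back to (B, heads, N, d).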
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + k = k * self.scale + attention = k.transpose(-1, -2) @ v + attention = attention.softmax(dim=-1) + x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class ChannelBlock(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ffn=True, + cpe_act=False, + ): + super().__init__() + + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.ffn = ffn + self.norm1 = norm_layer(dim) + self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + + if self.ffn: + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + else: + self.norm2 = None + self.mlp = None + self.drop_path2 = None + + def forward(self, x: Tensor): + B, C, H, W = x.shape + + x = self.cpe1(x).flatten(2).transpose(1, 2) + + cur = self.norm1(x) + cur = self.attn(cur) + x = x + self.drop_path1(cur) + + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)) + + if self.mlp is not None: + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).view(B, C, H, W) + + return x + + +def window_partition(x: Tensor, window_size: Tuple[int, int]): + """ + Args: + x: (B, H, W, C) + window_size (Tuple[int, int]): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (Tuple[int, int]): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module. + It supports both shifted and non-shifted windows. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True + """ + fused_attn: torch.jit.Final[bool] + + def __init__(self, dim, window_size, num_heads, qkv_bias=True): + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x: Tensor): + B_, N, C = x.shape + + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x + + +class SpatialBlock(nn.Module): + r""" Windows Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ffn=True, + cpe_act=False, + ): + super().__init__() + self.dim = dim + self.ffn = ffn + self.num_heads = num_heads + self.window_size = to_2tuple(window_size) + self.mlp_ratio = mlp_ratio + + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + if self.ffn: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + else: + self.norm2 = None + self.mlp = None + self.drop_path2 = None + + def forward(self, x: Tensor): + B, C, H, W = x.shape + + shortcut = self.cpe1(x).flatten(2).transpose(1, 2) + + x = self.norm1(shortcut) + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x_windows = window_partition(x, self.window_size) + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + x = window_reverse(attn_windows, self.window_size, Hp, Wp) + + # if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + x = shortcut + self.drop_path1(x) + + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)) + + if self.mlp is not None: + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).view(B, C, H, W) + + return x + + +class DaVitStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth=1, + downsample=True, + attn_types=('spatial', 'channel'), + num_heads=3, + window_size=7, + mlp_ratio=4, + qkv_bias=True, + drop_path_rates=(0, 0), + norm_layer=LayerNorm2d, + norm_layer_cl=nn.LayerNorm, + ffn=True, + cpe_act=False + ): + super().__init__() + + self.grad_checkpointing = False + + # downsample embedding layer at the beginning of each stage + if downsample: + self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer) + else: + self.downsample = nn.Identity() + + ''' + repeating alternating attention blocks in each stage + default: (spatial -> channel) x depth + + potential opportunity to integrate with a more general version of ByobNet/ByoaNet + since the logic is similar + ''' + stage_blocks = [] + for block_idx in range(depth): + dual_attention_block = [] + for attn_idx, attn_type in enumerate(attn_types): + if attn_type == 'spatial': + dual_attention_block.append(SpatialBlock( + dim=out_chs, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path_rates[block_idx], + norm_layer=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act, + window_size=window_size, + )) + elif attn_type == 'channel': + dual_attention_block.append(ChannelBlock( + dim=out_chs, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path_rates[block_idx], + norm_layer=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act + )) + stage_blocks.append(nn.Sequential(*dual_attention_block)) + self.blocks = nn.Sequential(*stage_blocks) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x: Tensor): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class DaVit(nn.Module): + r""" DaViT + A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645 + Supports arbitrary input sizes and pyramid feature extraction + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks in each stage.
Default: (1, 1, 3, 1) + embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768) + num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24) + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__( + self, + in_chans=3, + depths=(1, 1, 3, 1), + embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4, + qkv_bias=True, + norm_layer='layernorm2d', + norm_layer_cl='layernorm', + norm_eps=1e-5, + attn_types=('spatial', 'channel'), + ffn=True, + cpe_act=False, + drop_rate=0., + drop_path_rate=0., + num_classes=1000, + global_pool='avg', + head_norm_first=False, + ): + super().__init__() + num_stages = len(embed_dims) + assert num_stages == len(num_heads) == len(depths) + norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) + norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) + self.num_classes = num_classes + self.num_features = embed_dims[-1] + self.drop_rate = drop_rate + self.grad_checkpointing = False + self.feature_info = [] + + self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer) + in_chs = embed_dims[0] + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + stages = [] + for stage_idx in range(num_stages): + out_chs = embed_dims[stage_idx] + stage = DaVitStage( + in_chs, + out_chs, + depth=depths[stage_idx], + downsample=stage_idx > 0, + attn_types=attn_types, + num_heads=num_heads[stage_idx], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rates=dpr[stage_idx], + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act, + ) + in_chs = out_chs + stages.append(stage) + self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')] + + self.stages = nn.Sequential(*stages) + + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets + # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt + # FIXME generalize this structure to ClassifierHead + if head_norm_first: + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + ) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + for stage in self.stages: + stage.set_grad_checkpointing(enable=enable) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, 
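# Worked example for the stochastic-depth schedule built in __init__ above: with
# depths=(1, 1, 3, 1) and drop_path_rate=0.1, sum(depths)=6 rates are spaced linearly
# and then split per stage:
#
#   >>> import torch
#   >>> [x.tolist() for x in torch.linspace(0, 0.1, 6).split((1, 1, 3, 1))]
#   [[0.0], [0.02], [0.04, 0.06, 0.08], [0.1]]   # up to float32 rounding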
num_classes, global_pool=None): + self.head.reset(num_classes, global_pool) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ Remap MSFT checkpoints -> timm """ + if 'head.fc.weight' in state_dict: + return state_dict # non-MSFT checkpoint + + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + import re + out_dict = {} + for k, v in state_dict.items(): + k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k) + k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) + k = k.replace('downsample.proj', 'downsample.conv') + k = k.replace('stages.0.downsample', 'stem') + k = k.replace('head.', 'head.fc.') + k = k.replace('norms.', 'head.norm.') + k = k.replace('cpe.0', 'cpe1') + k = k.replace('cpe.1', 'cpe2') + out_dict[k] = v + return out_dict + + +def _create_davit(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + DaVit, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +# TODO contact authors to get larger pretrained models +default_cfgs = generate_default_cfgs({ + # official microsoft weights from https://github.com/dingmyu/davit + 'davit_tiny.msft_in1k': _cfg( + hf_hub_id='timm/'), + 'davit_small.msft_in1k': _cfg( + hf_hub_id='timm/'), + 'davit_base.msft_in1k': _cfg( + hf_hub_id='timm/'), + 'davit_large': _cfg(), + 'davit_huge': _cfg(), + 'davit_giant': _cfg(), +}) + + +@register_model +def davit_tiny(pretrained=False, **kwargs) -> DaVit: + model_args = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24)) + return _create_davit('davit_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def davit_small(pretrained=False, **kwargs) -> DaVit: + model_args = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24)) + return _create_davit('davit_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def davit_base(pretrained=False, **kwargs) -> DaVit: + model_args = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32)) + return _create_davit('davit_base', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def davit_large(pretrained=False, **kwargs) -> DaVit: + model_args = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48)) + return _create_davit('davit_large', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def davit_huge(pretrained=False, **kwargs) -> DaVit: + model_args = dict(depths=(1, 1, 
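# Sketch of the MSFT -> timm key surgery in checkpoint_filter_fn above, on hypothetical keys
# (note the substitution order: stage-0's downsample becomes the stem only after the
# downsample.proj -> downsample.conv rename):
#   'patch_embeds.0.proj.weight'      -> 'stages.0.downsample.conv.weight' -> 'stem.conv.weight'
#   'main_blocks.1.0.attn.qkv.weight' -> 'stages.1.blocks.0.attn.qkv.weight'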
9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64)) + return _create_davit('davit_huge', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def davit_giant(pretrained=False, **kwargs) -> DaVit: + model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96)) + return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/densenet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..ade61a14df89d49f83f1d55e4f998e84113ce787 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/densenet.py @@ -0,0 +1,421 @@ +"""Pytorch Densenet implementation w/ tweaks +This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with +fixed kwargs passthrough and addition of dynamic global avg/max pool. +""" +import re +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch.jit.annotations import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import MATCH_PREV_GROUP +from ._registry import register_model, generate_default_cfgs, register_model_deprecations + +__all__ = ['DenseNet'] + + +class DenseLayer(nn.Module): + def __init__( + self, + num_input_features, + growth_rate, + bn_size, + norm_layer=BatchNormAct2d, + drop_rate=0., + grad_checkpointing=False, + ): + super(DenseLayer, self).__init__() + self.add_module('norm1', norm_layer(num_input_features)), + self.add_module('conv1', nn.Conv2d( + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', norm_layer(bn_size * growth_rate)), + self.add_module('conv2', nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = float(drop_rate) + self.grad_checkpointing = grad_checkpointing + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + return cp.checkpoint(closure, *x) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, x): # noqa: F811 + if isinstance(x, torch.Tensor): + prev_features = [x] + else: + prev_features = x + + if self.grad_checkpointing and self.any_requires_grad(prev_features): + if 
torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bottleneck_fn(prev_features) + + new_features = self.conv2(self.norm2(bottleneck_output)) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + return new_features + + +class DenseBlock(nn.ModuleDict): + _version = 2 + + def __init__( + self, + num_layers, + num_input_features, + bn_size, + growth_rate, + norm_layer=BatchNormAct2d, + drop_rate=0., + grad_checkpointing=False, + ): + super(DenseBlock, self).__init__() + for i in range(num_layers): + layer = DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + norm_layer=norm_layer, + drop_rate=drop_rate, + grad_checkpointing=grad_checkpointing, + ) + self.add_module('denselayer%d' % (i + 1), layer) + + def forward(self, init_features): + features = [init_features] + for name, layer in self.items(): + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseTransition(nn.Sequential): + def __init__( + self, + num_input_features, + num_output_features, + norm_layer=BatchNormAct2d, + aa_layer=None, + ): + super(DenseTransition, self).__init__() + self.add_module('norm', norm_layer(num_input_features)) + self.add_module('conv', nn.Conv2d( + num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + if aa_layer is not None: + self.add_module('pool', aa_layer(num_output_features, stride=2)) + else: + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" `_ + + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 4 ints) - how many layers in each pooling block + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate before classifier layer + proj_drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, + but slower. Default: *False*. 
See `"paper" `_ + """ + + def __init__( + self, + growth_rate=32, + block_config=(6, 12, 24, 16), + num_classes=1000, + in_chans=3, + global_pool='avg', + bn_size=4, + stem_type='', + act_layer='relu', + norm_layer='batchnorm2d', + aa_layer=None, + drop_rate=0., + proj_drop_rate=0., + memory_efficient=False, + aa_stem_only=True, + ): + self.num_classes = num_classes + super(DenseNet, self).__init__() + norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) + + # Stem + deep_stem = 'deep' in stem_type # 3x3 deep stem + num_init_features = growth_rate * 2 + if aa_layer is None: + stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + stem_pool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=num_init_features, stride=2)]) + if deep_stem: + stem_chs_1 = stem_chs_2 = growth_rate + if 'tiered' in stem_type: + stem_chs_1 = 3 * (growth_rate // 4) + stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), + ('norm0', norm_layer(stem_chs_1)), + ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), + ('norm1', norm_layer(stem_chs_2)), + ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), + ('norm2', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + else: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('norm0', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + self.feature_info = [ + dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')] + current_stride = 4 + + # DenseBlocks + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + norm_layer=norm_layer, + drop_rate=proj_drop_rate, + grad_checkpointing=memory_efficient, + ) + module_name = f'denseblock{(i + 1)}' + self.features.add_module(module_name, block) + num_features = num_features + num_layers * growth_rate + transition_aa_layer = None if aa_stem_only else aa_layer + if i != len(block_config) - 1: + self.feature_info += [ + dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)] + current_stride *= 2 + trans = DenseTransition( + num_input_features=num_features, + num_output_features=num_features // 2, + norm_layer=norm_layer, + aa_layer=transition_aa_layer, + ) + self.features.add_module(f'transition{i + 1}', trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', norm_layer(num_features)) + + self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] + self.num_features = num_features + + # Linear layer + global_pool, classifier = create_classifier( + self.num_features, + self.num_classes, + pool_type=global_pool, + ) + self.global_pool = global_pool + self.head_drop = nn.Dropout(drop_rate) + self.classifier = classifier + + # Official init from torch repo. 
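# Worked feature-width bookkeeping for densenet121 (growth_rate=32, block_config=(6, 12, 24, 16)),
# following the loop above: start at 2 * 32 = 64 stem channels, each dense block adds
# num_layers * growth_rate, and each transition halves:
#   64 -> +6*32 = 256 -> /2 = 128 -> +12*32 = 512 -> /2 = 256
#      -> +24*32 = 1024 -> /2 = 512 -> +16*32 = 1024 (final width fed to the classifier)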
+ for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]', + blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [ + (r'^features\.denseblock(\d+)\.denselayer(\d+)', None), + (r'^features\.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for b in self.features.modules(): + if isinstance(b, DenseLayer): + b.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + x = self.head_drop(x) + x = self.classifier(x) + return x + + +def _filter_torchvision_pretrained(state_dict): + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + +def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): + kwargs['growth_rate'] = growth_rate + kwargs['block_config'] = block_config + return build_model_with_cfg( + DenseNet, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True), + pretrained_filter_fn=_filter_torchvision_pretrained, + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.conv0', 'classifier': 'classifier', **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'densenet121.ra_in1k': _cfg( + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'densenetblur121d.ra_in1k': _cfg( + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'densenet264d.untrained': _cfg(), + 'densenet121.tv_in1k': _cfg(hf_hub_id='timm/'), + 'densenet169.tv_in1k': _cfg(hf_hub_id='timm/'), + 'densenet201.tv_in1k': _cfg(hf_hub_id='timm/'), + 'densenet161.tv_in1k': _cfg(hf_hub_id='timm/'), +}) + + +@register_model +def densenet121(pretrained=False, **kwargs) -> DenseNet: + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model_args = dict(growth_rate=32, block_config=(6, 12, 24, 16)) + model = _create_densenet('densenet121', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def densenetblur121d(pretrained=False, **kwargs) -> DenseNet: + r"""Densenet-121 w/ blur-pooling & 3-layer 3x3 stem + `"Densely Connected Convolutional Networks" ` + """ + model_args = dict(growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', aa_layer=BlurPool2d) + model = _create_densenet('densenetblur121d', 
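# _filter_torchvision_pretrained above flattens torchvision's nested layer indices into the
# timm naming, e.g. (hypothetical key):
#   'features.denseblock1.denselayer1.norm.1.weight' -> 'features.denseblock1.denselayer1.norm1.weight'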
pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def densenet169(pretrained=False, **kwargs) -> DenseNet: + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" ` + """ + model_args = dict(growth_rate=32, block_config=(6, 12, 32, 32)) + model = _create_densenet('densenet169', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def densenet201(pretrained=False, **kwargs) -> DenseNet: + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" ` + """ + model_args = dict(growth_rate=32, block_config=(6, 12, 48, 32)) + model = _create_densenet('densenet201', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def densenet161(pretrained=False, **kwargs) -> DenseNet: + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" ` + """ + model_args = dict(growth_rate=48, block_config=(6, 12, 36, 24)) + model = _create_densenet('densenet161', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def densenet264d(pretrained=False, **kwargs) -> DenseNet: + r"""Densenet-264 model from + `"Densely Connected Convolutional Networks" ` + """ + model_args = dict(growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep') + model = _create_densenet('densenet264d', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +register_model_deprecations(__name__, { + 'tv_densenet121': 'densenet121.tv_in1k', +}) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/edgenext.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/edgenext.py new file mode 100644 index 0000000000000000000000000000000000000000..661669d5edcf7fd19767d72952d21f7cb15b688b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/edgenext.py @@ -0,0 +1,576 @@ +""" EdgeNeXt + +Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications` + - https://arxiv.org/abs/2206.10589 + +Original code and weights from https://github.com/mmaaz60/EdgeNeXt + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +import math +from collections import OrderedDict +from functools import partial +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \ + use_fused_attn, NormMlpClassifierHead, ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model, generate_default_cfgs + +__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this + + +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method +class PositionalEncodingFourier(nn.Module): + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + + def forward(self, shape: Tuple[int, int, int]): + device = self.token_projection.weight.device + dtype = self.token_projection.weight.dtype + inv_mask = 
~torch.zeros(shape).to(device=device, dtype=torch.bool) + y_embed = inv_mask.cumsum(1, dtype=torch.float32) + x_embed = inv_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos.to(dtype)) + + return pos + + +class ConvBlock(nn.Module): + def __init__( + self, + dim, + dim_out=None, + kernel_size=7, + stride=1, + conv_bias=True, + expand_ratio=4, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, drop_path=0., + ): + super().__init__() + dim_out = dim_out or dim + self.shortcut_after_dw = stride > 1 or dim != dim_out + + self.conv_dw = create_conv2d( + dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class CrossCovarianceAttn(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v) + + x = x.permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class SplitTransposeBlock(nn.Module): + def __init__( + self, + dim, + num_scales=1, + num_heads=8, + expand_ratio=4, + use_pos_emb=True, + conv_bias=True, + qkv_bias=True, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_path=0., + attn_drop=0., + proj_drop=0. 
+ ): + super().__init__() + width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales))) + self.width = width + self.num_scales = max(1, num_scales - 1) + + convs = [] + for i in range(self.num_scales): + convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias)) + self.convs = nn.ModuleList(convs) + + self.pos_embd = None + if use_pos_emb: + self.pos_embd = PositionalEncodingFourier(dim=dim) + self.norm_xca = norm_layer(dim) + self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.xca = CrossCovarianceAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + + self.norm = norm_layer(dim, eps=1e-6) + self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + # scales code re-written for torchscript as per my res2net fixes -rw + # NOTE torch.split(x, self.width, 1) causing issues with ONNX export + spx = x.chunk(len(self.convs) + 1, dim=1) + spo = [] + sp = spx[0] + for i, conv in enumerate(self.convs): + if i > 0: + sp = sp + spx[i] + sp = conv(sp) + spo.append(sp) + spo.append(spx[-1]) + x = torch.cat(spo, 1) + + # XCA + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embd is not None: + pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class EdgeNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + num_global_blocks=1, + num_heads=4, + scales=2, + kernel_size=7, + expand_ratio=4, + use_pos_emb=False, + downsample_block=False, + conv_bias=True, + ls_init_value=1.0, + drop_path_rates=None, + norm_layer=LayerNorm2d, + norm_layer_cl=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU + ): + super().__init__() + self.grad_checkpointing = False + + if downsample_block or stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias) + ) + in_chs = out_chs + + stage_blocks = [] + for i in range(depth): + if i < depth - num_global_blocks: + stage_blocks.append( + ConvBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + conv_bias=conv_bias, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + else: + stage_blocks.append( + SplitTransposeBlock( + dim=in_chs, + num_scales=scales, + num_heads=num_heads, + expand_ratio=expand_ratio, + use_pos_emb=use_pos_emb, + conv_bias=conv_bias, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x 
= self.blocks(x) + return x + + +class EdgeNeXt(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + dims=(24, 48, 88, 168), + depths=(3, 3, 9, 3), + global_block_counts=(0, 1, 1, 1), + kernel_sizes=(3, 5, 7, 9), + heads=(8, 8, 8, 8), + d2_scales=(2, 2, 3, 4), + use_pos_emb=(False, True, False, False), + ls_init_value=1e-6, + head_init_scale=1., + expand_ratio=4, + downsample_block=False, + conv_bias=True, + stem_type='patch', + head_norm_first=False, + act_layer=nn.GELU, + drop_path_rate=0., + drop_rate=0., + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + norm_layer = partial(LayerNorm2d, eps=1e-6) + norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + self.feature_info = [] + + assert stem_type in ('patch', 'overlap') + if stem_type == 'patch': + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias), + norm_layer(dims[0]), + ) + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias), + norm_layer(dims[0]), + ) + + curr_stride = 4 + stages = [] + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + in_chs = dims[0] + for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride + stages.append(EdgeNeXtStage( + in_chs=in_chs, + out_chs=dims[i], + stride=stride, + depth=depths[i], + num_global_blocks=global_block_counts[i], + num_heads=heads[i], + drop_path_rates=dp_rates[i], + scales=d2_scales[i], + expand_ratio=expand_ratio, + kernel_size=kernel_sizes[i], + use_pos_emb=use_pos_emb[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + act_layer=act_layer, + )) + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + in_chs = dims[i] + self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')] + + self.stages = nn.Sequential(*stages) + + self.num_features = dims[-1] + if head_norm_first: + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + ) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + self.head.reset(num_classes, global_pool) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) 
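A minimal smoke test for the forward path split out above (illustrative only; assumes this vendored module is importable as timm.models.edgenext and uses the default config):

    import torch
    from timm.models.edgenext import EdgeNeXt

    model = EdgeNeXt(num_classes=10).eval()
    x = torch.randn(1, 3, 256, 256)
    feats = model.forward_features(x)                   # (1, 168, 8, 8) with default dims and total stride 32
    logits = model.forward_head(feats)                  # pool + norm + fc -> (1, 10)
    embed = model.forward_head(feats, pre_logits=True)  # pooled features just before the final fc
    print(feats.shape, logits.shape, embed.shape)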
+ return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + nn.init.zeros_(module.bias) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint + + # models were released as train checkpoints... :/ + if 'model_ema' in state_dict: + state_dict = state_dict['model_ema'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + elif 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_edgenext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EdgeNeXt, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'edgenext_xx_small.in1k': _cfg( + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'edgenext_x_small.in1k': _cfg( + hf_hub_id='timm/', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'edgenext_small.usi_in1k': _cfg( # USI weights + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + 'edgenext_base.usi_in1k': _cfg( # USI weights + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + 'edgenext_base.in21k_ft_in1k': _cfg( # USI weights + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + 'edgenext_small_rw.sw_in1k': _cfg( + hf_hub_id='timm/', + test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), +}) + + +@register_model +def edgenext_xx_small(pretrained=False, **kwargs) -> EdgeNeXt: + # 1.33M & 260.58M @ 256 resolution + # 71.23% Top-1 accuracy + # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS + # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS + model_args = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4)) + return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def edgenext_x_small(pretrained=False, **kwargs) 
-> EdgeNeXt: + # 2.34M & 538.0M @ 256 resolution + # 75.00% Top-1 accuracy + # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=31.61 versus 28.49 for MobileViT_XS + # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS + model_args = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4)) + return _create_edgenext('edgenext_x_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def edgenext_small(pretrained=False, **kwargs) -> EdgeNeXt: + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304)) + return _create_edgenext('edgenext_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def edgenext_base(pretrained=False, **kwargs) -> EdgeNeXt: + # 18.51M & 3840.93M @ 256 resolution + # 82.5% (normal) 83.7% (USI) Top-1 accuracy + # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=xx.xx versus xx.xx for MobileViT_S + # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx + model_args = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584]) + return _create_edgenext('edgenext_base', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def edgenext_small_rw(pretrained=False, **kwargs) -> EdgeNeXt: + model_args = dict( + depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), + downsample_block=True, conv_bias=False, stem_type='overlap') + return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **dict(model_args, **kwargs)) + diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6e61d1bfde6bb9a39aa0eadc463feef95364b94b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientnet.py @@ -0,0 +1,2342 @@ +""" The EfficientNet Family in PyTorch + +An implementation of EfficientNet that covers a variety of related models with efficient architectures: + +* EfficientNet-V2 + - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing
Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* TinyNet + - Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets - https://arxiv.org/abs/2010.14819 + - Definitions & weights borrowed from https://github.com/huawei-noah/CV-Backbones/tree/master/tinynet_pytorch + +* And likely more... + +The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available +by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing +the models and weights open source! + +Hacked together by / Copyright 2019, Ross Wightman +""" +from functools import partial +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['EfficientNet', 'EfficientNetFeatures'] + + +class EfficientNet(nn.Module): + """ EfficientNet + + A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet-V2 Small, Medium, Large, XL & B0-B3 + * EfficientNet B0-B8, L2 + * EfficientNet-EdgeTPU + * EfficientNet-CondConv + * MixNet S, M, L, XL + * MnasNet A1, B1, and small + * MobileNet-V2 + * FBNet C + * Single-Path NAS Pixel1 + * TinyNet + """ + + def __init__( + self, + block_args, + num_classes=1000, + num_features=1280, + in_chans=3, + stem_size=32, + fix_stem=False, + output_stride=32, + pad_type='', + round_chs_fn=round_channels, + act_layer=None, + norm_layer=None, + se_layer=None, + drop_rate=0., + drop_path_rate=0., + global_pool='avg' + ): + super(EfficientNet, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + se_layer = se_layer or SqueezeExcite + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self.grad_checkpointing = False + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_act_layer(stem_size, inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + ) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_act_layer(self.num_features, inplace=True) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + efficientnet_init_weights(self) + + def 
as_sequential(self): + layers = [self.conv_stem, self.bn1] + layers.extend(self.blocks) + layers.extend([self.conv_head, self.bn2, self.global_pool]) + layers.extend([nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^conv_stem|bn1', + blocks=[ + (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None), + (r'conv_head|bn2', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x, flatten=True) + else: + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return x if pre_logits else self.classifier(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +class EfficientNetFeatures(nn.Module): + """ EfficientNet Feature Extractor + + A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation + and object detection models. + """ + + def __init__( + self, + block_args, + out_indices=(0, 1, 2, 3, 4), + feature_location='bottleneck', + in_chans=3, + stem_size=32, + fix_stem=False, + output_stride=32, + pad_type='', + round_chs_fn=round_channels, + act_layer=None, + norm_layer=None, + se_layer=None, + drop_rate=0., + drop_path_rate=0. 
+ ): + super(EfficientNetFeatures, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + se_layer = se_layer or SqueezeExcite + self.drop_rate = drop_rate + self.grad_checkpointing = False + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_act_layer(stem_size, inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + feature_location=feature_location, + ) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(b, x) + else: + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def _create_effnet(variant, pretrained=False, **kwargs): + features_mode = '' + model_cls = EfficientNet + kwargs_filter = None + if kwargs.pop('features_only', False): + if 'feature_cfg' in kwargs: + features_mode = 'cfg' + else: + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') + model_cls = EfficientNetFeatures + features_mode = 'cls' + + model = build_model_with_cfg( + model_cls, + variant, + pretrained, + features_only=features_mode == 'cfg', + pretrained_strict=features_mode != 'cls', + kwargs_filter=kwargs_filter, + **kwargs, + ) + if features_mode == 'cls': + model.pretrained_cfg = model.default_cfg = pretrained_cfg_for_features(model.pretrained_cfg) + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
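Example (an illustrative note, not from the original source; printed values assume timm's default divisor=8 rounding)::

        from timm.models._efficientnet_builder import round_channels
        # scale the per-stage channel targets below by 0.75, as semnasnet_075 would
        print([round_channels(c, multiplier=0.75) for c in (16, 24, 40, 80, 112, 160, 320)])
        # -> [16, 24, 32, 64, 88, 120, 240]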
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)), + stem_size=32, + fix_stem=fix_stem_head, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'relu6'), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet( + variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8, + group_size=None, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_edge( + variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs): + """ Creates an EfficientNet-EdgeTPU model + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu + """ + + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 
'relu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an EfficientNet-CondConv model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + # NOTE unlike official impl, this one uses the `cc<x>` option where x is the base number of experts for each stage and + # the experts_multiplier increases that on a per-model basis as with depth/channel multipliers + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'swish'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=resolve_act_layer(kwargs, 'relu6'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_base( + variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 base model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + arch_def = [ + ['cn_r1_k3_s1_e1_c16_skip'], + ['er_r2_k3_s2_e4_c32'], + ['er_r2_k3_s2_e4_c48'], + ['ir_r3_k3_s2_e4_c96_se0.25'], + ['ir_r5_k3_s1_e6_c112_se0.25'], + ['ir_r8_k3_s2_e6_c192_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
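The round_limit=0. just above disables make_divisible's 90% floor, so V2-base channel counts may round down to the nearest multiple of 8 rather than always bumping up (an illustrative sketch, not part of the vendored file):

    from timm.models._efficientnet_builder import round_channels

    print(round_channels(27))                  # 32 -> default round_limit=0.9 refuses to shrink >10%, bumps up
    print(round_channels(27, round_limit=0.))  # 24 -> plain round-to-nearest multiple of 8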
+ model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_s( + variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Small model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + + NOTE: `rw` flag sets up 'small' variant to behave like my initial v2 small model, + before ref the impl was released. + """ + arch_def = [ + ['cn_r2_k3_s1_e1_c24_skip'], + ['er_r4_k3_s2_e4_c48'], + ['er_r4_k3_s2_e4_c64'], + ['ir_r6_k3_s2_e4_c128_se0.25'], + ['ir_r9_k3_s1_e6_c160_se0.25'], + ['ir_r15_k3_s2_e6_c256_se0.25'], + ] + num_features = 1280 + if rw: + # my original variant, based on paper figure differs from the official release + arch_def[0] = ['er_r2_k3_s1_e1_c24'] + arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25'] + num_features = 1792 + + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), + num_features=round_chs_fn(num_features), + stem_size=24, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Medium model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r3_k3_s1_e1_c24_skip'], + ['er_r5_k3_s2_e4_c48'], + ['er_r5_k3_s2_e4_c80'], + ['ir_r7_k3_s2_e4_c160_se0.25'], + ['ir_r14_k3_s1_e6_c176_se0.25'], + ['ir_r18_k3_s2_e6_c304_se0.25'], + ['ir_r5_k3_s1_e6_c512_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=24, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Large model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r4_k3_s1_e1_c32_skip'], + ['er_r7_k3_s2_e4_c64'], + ['er_r7_k3_s2_e4_c96'], + ['ir_r10_k3_s2_e4_c192_se0.25'], + ['ir_r19_k3_s1_e6_c224_se0.25'], + ['ir_r25_k3_s2_e6_c384_se0.25'], + ['ir_r7_k3_s1_e6_c640_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=32, + 
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Xtra-Large model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r4_k3_s1_e1_c32_skip'], + ['er_r8_k3_s2_e4_c64'], + ['er_r8_k3_s2_e4_c96'], + ['ir_r16_k3_s2_e4_c192_se0.25'], + ['ir_r24_k3_s1_e6_c256_se0.25'], + ['ir_r32_k3_s2_e6_c512_se0.25'], + ['ir_r8_k3_s1_e6_c640_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. 
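Example (an illustrative note, not from the original source): the k3.5.7-style tokens below request MixConv, a depthwise conv whose channels are split across several kernel sizes; passing a kernel-size list to create_conv2d triggers the same path::

        import torch
        from timm.layers import create_conv2d
        mixed = create_conv2d(40, 40, kernel_size=[3, 5, 7], depthwise=True)
        print(type(mixed).__name__)                      # MixedConv2d
        print(mixed(torch.randn(1, 40, 32, 32)).shape)   # torch.Size([1, 40, 32, 32])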
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_tinynet( + variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs +): + """Creates a TinyNet model. + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=max(1280, round_channels(1280, model_width, 8, None)), + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=model_width), + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'mnasnet_050.untrained': _cfg(), + 'mnasnet_075.untrained': _cfg(), + 'mnasnet_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + hf_hub_id='timm/'), + 'mnasnet_140.untrained': _cfg(), + + 'semnasnet_050.untrained': _cfg(), + 'semnasnet_075.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth', + hf_hub_id='timm/'), + 'semnasnet_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + hf_hub_id='timm/'), + 'semnasnet_140.untrained': _cfg(), + 'mnasnet_small.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_small_lamb-aff75073.pth', + hf_hub_id='timm/'), + + 'mobilenetv2_035.untrained': _cfg(), + 'mobilenetv2_050.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth', + hf_hub_id='timm/', + 
interpolation='bicubic', + ), + 'mobilenetv2_075.untrained': _cfg(), + 'mobilenetv2_100.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + hf_hub_id='timm/'), + 'mobilenetv2_110d.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + hf_hub_id='timm/'), + 'mobilenetv2_120d.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + hf_hub_id='timm/'), + 'mobilenetv2_140.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + hf_hub_id='timm/'), + + 'fbnetc_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + hf_hub_id='timm/', + interpolation='bilinear'), + 'spnasnet_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + hf_hub_id='timm/', + interpolation='bilinear'), + + # NOTE experimenting with alternate attention + 'efficientnet_b0.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + hf_hub_id='timm/'), + 'efficientnet_b1.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + hf_hub_id='timm/', + test_input_size=(3, 256, 256), crop_pct=1.0), + 'efficientnet_b2.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), + 'efficientnet_b3.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b4.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', + hf_hub_id='timm/', + input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), + 'efficientnet_b5.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, crop_mode='squash'), + 'efficientnet_b5.sw_in12k': _cfg( + hf_hub_id='timm/', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.95, num_classes=11821), + 'efficientnet_b6.untrained': _cfg( + url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'efficientnet_b7.untrained': _cfg( + url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8.untrained': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'efficientnet_l2.untrained': _cfg( + url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + + # FIXME experimental + 'efficientnet_b0_gn.untrained': _cfg(), + 'efficientnet_b0_g8_gn.untrained': _cfg(), + 'efficientnet_b0_g16_evos.untrained': _cfg(), + 'efficientnet_b3_gn.untrained': _cfg( + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b3_g8_gn.untrained': _cfg( + input_size=(3, 288, 288), 
pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + + 'efficientnet_es.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + hf_hub_id='timm/'), + 'efficientnet_em.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_el.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el-3b455510.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_es_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_pruned75-1b7248cf.pth', + hf_hub_id='timm/'), + 'efficientnet_el_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el_pruned70-ef2a2ccf.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_cc_b0_4e.untrained': _cfg(), + 'efficientnet_cc_b0_8e.untrained': _cfg(), + 'efficientnet_cc_b1_8e.untrained': _cfg(input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'efficientnet_lite0.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + hf_hub_id='timm/'), + 'efficientnet_lite1.untrained': _cfg( + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_lite2.untrained': _cfg( + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_lite3.untrained': _cfg( + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_lite4.untrained': _cfg( + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + + 'efficientnet_b1_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb1_pruned-bea43a3a.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), + crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b2_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb2_pruned-08c1b27c.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), + crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b3_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb3_pruned-59ecf72d.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), + crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'efficientnetv2_rw_t.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_t_agc-3620981a.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), + 'gc_efficientnetv2_rw_t.agc_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gc_efficientnetv2_rw_t_agc-927a0bde.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), + 'efficientnetv2_rw_s.ra2_in1k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', + hf_hub_id='timm/', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_rw_m.agc_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth', + hf_hub_id='timm/', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + + 'efficientnetv2_s.untrained': _cfg( + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_m.untrained': _cfg( + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + 'efficientnetv2_l.untrained': _cfg( + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'efficientnetv2_xl.untrained': _cfg( + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnet_b0.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + hf_hub_id='timm/', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + hf_hub_id='timm/', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + hf_hub_id='timm/', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_l2.ns_jft_in1k_475': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + hf_hub_id='timm/', + input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), + 'tf_efficientnet_l2.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + hf_hub_id='timm/', + input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), + + 
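+    # NOTE: each _cfg entry records how its checkpoint expects data: `input_size`
+    # is the train resolution, `test_input_size` an optional larger eval
+    # resolution, `pool_size` the final feature-map size at that input size, and
+    # `crop_pct` the eval center-crop fraction. A minimal usage sketch (hedged;
+    # assumes only the public timm API, tag names as registered in this file):
+    #
+    #   import timm
+    #   model = timm.create_model('tf_efficientnet_b4.ns_jft_in1k', pretrained=True)
+    #   data_cfg = timm.data.resolve_data_config(model=model)  # reads these fields
+    #   transform = timm.data.create_transform(**data_cfg)
+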
'tf_efficientnet_b0.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b5.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b7.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + hf_hub_id='timm/', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + hf_hub_id='timm/', + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0.aa_in1k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + hf_hub_id='timm/', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.aa_in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + hf_hub_id='timm/', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7.aa_in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth', + hf_hub_id='timm/', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + + 'tf_efficientnet_b0.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth', + hf_hub_id='timm/', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + + + 'tf_efficientnet_es.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), 
std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), ), + 'tf_efficientnet_em.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_el.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'tf_efficientnet_cc_b0_4e.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'tf_efficientnet_lite0.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite1.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite2.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite3.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), + 'tf_efficientnet_lite4.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), + + 'tf_efficientnetv2_s.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), 
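+    # NOTE: crop_mode='squash' (used on the larger V2 entries below) makes the
+    # eval transform resize directly to the target size without preserving
+    # aspect ratio, instead of the default shortest-edge resize + center crop;
+    # combined with crop_pct=1.0 the whole image is seen at test time.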
+ 'tf_efficientnetv2_m.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_l.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_xl.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'tf_efficientnetv2_s.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_l.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'tf_efficientnetv2_s.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_l.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_xl.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, 
crop_mode='squash'), + + 'tf_efficientnetv2_b0.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', + hf_hub_id='timm/', + input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)), + 'tf_efficientnetv2_b1.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth', + hf_hub_id='timm/', + input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882), + 'tf_efficientnetv2_b2.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth', + hf_hub_id='timm/', + input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890), + 'tf_efficientnetv2_b3.in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.9, crop_mode='squash'), + 'tf_efficientnetv2_b3.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + 'tf_efficientnetv2_b3.in21k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=21843, + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + + 'mixnet_s.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + hf_hub_id='timm/'), + 'mixnet_m.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + hf_hub_id='timm/'), + 'mixnet_l.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + hf_hub_id='timm/'), + 'mixnet_xl.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + hf_hub_id='timm/'), + 'mixnet_xxl.untrained': _cfg(), + + 'tf_mixnet_s.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + hf_hub_id='timm/'), + 'tf_mixnet_m.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + hf_hub_id='timm/'), + 'tf_mixnet_l.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', + hf_hub_id='timm/'), + + "tinynet_a.in1k": _cfg( + input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth', + hf_hub_id='timm/'), + "tinynet_b.in1k": _cfg( + input_size=(3, 188, 188), pool_size=(6, 6), # int(224 * 0.84) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth', + hf_hub_id='timm/'), + "tinynet_c.in1k": _cfg( + input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth', + hf_hub_id='timm/'), + "tinynet_d.in1k": _cfg( + input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68) + 
url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth', + hf_hub_id='timm/'), + "tinynet_e.in1k": _cfg( + input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth', + hf_hub_id='timm/'), +}) + + +@register_model +def mnasnet_050(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_075(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_100(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_140(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_050(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_075(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_100(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_140(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_small(pretrained=False, **kwargs) -> EfficientNet: + """ MNASNet Small, depth multiplier of 1.0. 
""" + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_035(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 0.35 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_035', 0.35, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_050(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 0.5 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_075(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 0.75 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_100(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_140(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_110d(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_120d(pretrained=False, **kwargs) -> EfficientNet: + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetc_100(pretrained=False, **kwargs) -> EfficientNet: + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def spnasnet_100(pretrained=False, **kwargs) -> EfficientNet: + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, 
drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b4(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B4 """
+    # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b5(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B5 """
+    # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b6(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B6 """
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b7(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B7 """
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b8(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B8 """
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_l2(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-L2."""
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+    return model
+
+
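+# NOTE: the efficientnet_b0..b8 / l2 defs above differ only in their
+# (channel_multiplier, depth_multiplier) pair per the EfficientNet compound
+# scaling rule; the resolution leg of that scaling lives in the pretrained cfgs
+# (input_size / test_input_size) above, not in the architecture itself.
+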
+# FIXME experimental group conv / GroupNorm / EvoNorm experiments
+@register_model
+def efficientnet_b0_gn(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B0 + GroupNorm"""
+    model = _gen_efficientnet(
+        'efficientnet_b0_gn', norm_layer=partial(GroupNormAct, group_size=8), pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b0_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B0 w/ group conv + GroupNorm"""
+    model = _gen_efficientnet(
+        'efficientnet_b0_g8_gn', group_size=8, norm_layer=partial(GroupNormAct, group_size=8),
+        pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b0_g16_evos(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B0 w/ group 16 conv + EvoNorm"""
+    model = _gen_efficientnet(
+        'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16,
+        pretrained=pretrained, **kwargs)  # norm_layer=partial(EvoNorm2dS0, group_size=16),
+    return model
+
+
+@register_model
+def efficientnet_b3_gn(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B3 w/ GroupNorm """
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b3_gn', channel_multiplier=1.2, depth_multiplier=1.4, channel_divisor=16,
+        norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B3 w/ grouped conv + GroupNorm"""
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b3_g8_gn', channel_multiplier=1.2, depth_multiplier=1.4, group_size=8, channel_divisor=16,
+        norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Edge Small. """
+    model = _gen_efficientnet_edge(
+        'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_es_pruned(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0"""
+    model = _gen_efficientnet_edge(
+        'efficientnet_es_pruned', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_em(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Edge-Medium. """
+    model = _gen_efficientnet_edge(
+        'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_el(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Edge-Large. """
+    model = _gen_efficientnet_edge(
+        'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+    return model
+
+
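+# NOTE: the _es/_em/_el "Edge" defs above follow the EfficientNet-EdgeTPU
+# variants, fusing the expansion conv into a full conv and dropping SE in the
+# early stages; see _gen_efficientnet_edge for the exact block args.
+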
+@register_model
+def efficientnet_el_pruned(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Edge-Large pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0"""
+    model = _gen_efficientnet_edge(
+        'efficientnet_el_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_cc_b0_4e(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-CondConv-B0 w/ 4 Experts """
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+    model = _gen_efficientnet_condconv(
+        'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_cc_b0_8e(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-CondConv-B0 w/ 8 Experts """
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+    model = _gen_efficientnet_condconv(
+        'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
+        pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_cc_b1_8e(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-CondConv-B1 w/ 8 Experts """
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+    model = _gen_efficientnet_condconv(
+        'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
+        pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_lite0(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Lite0 """
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+    model = _gen_efficientnet_lite(
+        'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_lite1(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Lite1 """
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+    model = _gen_efficientnet_lite(
+        'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_lite2(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Lite2 """
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+    model = _gen_efficientnet_lite(
+        'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_lite3(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Lite3 """
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+    model = _gen_efficientnet_lite(
+        'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_lite4(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-Lite4 """
+    # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+    model = _gen_efficientnet_lite(
+        'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+    return model
+
+
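+# NOTE: the _lite defs above remove squeeze-and-excite and use ReLU6 in place of
+# SiLU/Swish (see _gen_efficientnet_lite), the usual trade of a little accuracy
+# for quantization/mobile friendliness.
+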
+@register_model
+def efficientnet_b1_pruned(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
+    kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
+    kwargs.setdefault('pad_type', 'same')
+    variant = 'efficientnet_b1_pruned'
+    model = _gen_efficientnet(
+        variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b2_pruned(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
+    kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
+    kwargs.setdefault('pad_type', 'same')
+    model = _gen_efficientnet(
+        'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True,
+        pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_b3_pruned(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
+    kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
+    kwargs.setdefault('pad_type', 'same')
+    model = _gen_efficientnet(
+        'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True,
+        pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_rw_t(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Tiny (Custom variant, tiny not in paper). """
+    model = _gen_efficientnetv2_s(
+        'efficientnetv2_rw_t', channel_multiplier=0.8, depth_multiplier=0.9, rw=False, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def gc_efficientnetv2_rw_t(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Tiny w/ Global Context Attn (Custom variant, tiny not in paper). """
+    model = _gen_efficientnetv2_s(
+        'gc_efficientnetv2_rw_t', channel_multiplier=0.8, depth_multiplier=0.9,
+        rw=False, se_layer='gc', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_rw_s(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Small (RW variant).
+    NOTE: This is my initial (pre official code release) impl w/ some differences.
+    See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding
+    """
+    model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_rw_m(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Medium (RW variant).
+    """
+    model = _gen_efficientnetv2_s(
+        'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True,
+        pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_s(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Small. """
+    model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_m(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Medium. """
+    model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_l(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Large. """
+    model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_xl(pretrained=False, **kwargs) -> EfficientNet:
+    """ EfficientNet-V2 Xtra-Large.
""" + model = _gen_efficientnetv2_xl('efficientnetv2_xl', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B1. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B4. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B5. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B6. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B8. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet( + 'tf_efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_es(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_em(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_el(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-CondConv-B1 w/ 8 Experts. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite0(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite1(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite2(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite3(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite4(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2 Small. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2 Medium. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2 Large. 
Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_xl(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2 Xtra-Large. Tensorflow compatible variant + """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b0(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2-B0. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b1(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2-B1. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b2(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2-B2. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b3(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-V2-B3. Tensorflow compatible variant """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_s(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Small model. + """ + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_m(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Medium model. + """ + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_l(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Large model. + """ + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xl(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xxl(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. 
+ """ + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_s(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_m(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_l(pretrained=False, **kwargs) -> EfficientNet: + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tinynet_a(pretrained=False, **kwargs) -> EfficientNet: + model = _gen_tinynet('tinynet_a', 1.0, 1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tinynet_b(pretrained=False, **kwargs) -> EfficientNet: + model = _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tinynet_c(pretrained=False, **kwargs) -> EfficientNet: + model = _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tinynet_d(pretrained=False, **kwargs) -> EfficientNet: + model = _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tinynet_e(pretrained=False, **kwargs) -> EfficientNet: + model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs) + return model + + +register_model_deprecations(__name__, { + 'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k', + 'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k', + 'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k', + 'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k', + 'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k', + 'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k', + 'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k', + 'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k', + 'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k', + 'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k', + 'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k', + 'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k', + 'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k', + 'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k', + 'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k', + 'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k', + 'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k', + 'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475', + 'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k', + 'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k', + 'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k', + 'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k', + 'tf_efficientnetv2_xl_in21ft1k': 
'tf_efficientnetv2_xl.in21k_ft_in1k',
+    'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
+    'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
+    'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
+    'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
+    'efficientnet_b2a': 'efficientnet_b2',
+    'efficientnet_b3a': 'efficientnet_b3',
+    'mnasnet_a1': 'semnasnet_100',
+    'mnasnet_b1': 'mnasnet_100',
+})
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientvit_mit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientvit_mit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1960d3d24c903a9a213f69411eb3fee4204b62ca
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientvit_mit.py
@@ -0,0 +1,1098 @@
+""" EfficientViT (by MIT Song Han's Lab)
+
+Paper: `EfficientViT: Enhanced Linear Attention for High-Resolution Low-Computation Visual Recognition`
+    - https://arxiv.org/abs/2205.14756
+
+Adapted from official impl at https://github.com/mit-han-lab/efficientvit
+"""
+
+__all__ = ['EfficientVit']
+from typing import Any, Optional, Tuple, Union
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+
+
+def val2list(x: Union[list, tuple, Any], repeat_time=1) -> list:
+    if isinstance(x, (list, tuple)):
+        return list(x)
+    return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
+    # repeat elements if necessary
+    x = val2list(x)
+    if len(x) > 0:
+        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+    return tuple(x)
+
+
+def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
+    if isinstance(kernel_size, tuple):
+        return tuple([get_same_padding(ks) for ks in kernel_size])
+    else:
+        assert kernel_size % 2 > 0, "kernel size should be an odd number"
+        return kernel_size // 2
+
+
+class ConvNormAct(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size=3,
+            stride=1,
+            dilation=1,
+            groups=1,
+            bias=False,
+            dropout=0.,
+            norm_layer=nn.BatchNorm2d,
+            act_layer=nn.ReLU,
+    ):
+        super(ConvNormAct, self).__init__()
+        self.dropout = nn.Dropout(dropout, inplace=False)
+        self.conv = create_conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+        self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
+        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
+
+    def forward(self, x):
+        x = self.dropout(x)
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+class DSConv(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size=3,
+            stride=1,
+            use_bias=False,
+            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
+            act_layer=(nn.ReLU6, None),
+    ):
+        super(DSConv, self).__init__()
+        use_bias = val2tuple(use_bias, 2)
+        norm_layer = val2tuple(norm_layer, 2)
+        act_layer = val2tuple(act_layer, 2)
+
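+        # val2tuple broadcasts a scalar or short sequence to the requested length
+        # by repeating its last element, so a single spec configures both the
+        # depthwise and pointwise convs here, e.g. val2tuple(False, 2) ->
+        # (False, False) and val2tuple((nn.ReLU6, None), 2) -> (nn.ReLU6, None).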
self.depth_conv = ConvNormAct( + in_channels, + in_channels, + kernel_size, + stride, + groups=in_channels, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.point_conv = ConvNormAct( + in_channels, + out_channels, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=1, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, None), + ): + super(ConvBlock, self).__init__() + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.conv1 = ConvNormAct( + in_channels, + mid_channels, + kernel_size, + stride, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.conv2 = ConvNormAct( + mid_channels, + out_channels, + kernel_size, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class MBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, nn.ReLU6, None), + ): + super(MBConv, self).__init__() + use_bias = val2tuple(use_bias, 3) + norm_layer = val2tuple(norm_layer, 3) + act_layer = val2tuple(act_layer, 3) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.inverted_conv = ConvNormAct( + in_channels, + mid_channels, + 1, + stride=1, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.depth_conv = ConvNormAct( + mid_channels, + mid_channels, + kernel_size, + stride=stride, + groups=mid_channels, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + self.point_conv = ConvNormAct( + mid_channels, + out_channels, + 1, + norm_layer=norm_layer[2], + act_layer=act_layer[2], + bias=use_bias[2], + ) + + def forward(self, x): + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class FusedMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + groups=1, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, None), + ): + super(FusedMBConv, self).__init__() + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.spatial_conv = ConvNormAct( + in_channels, + mid_channels, + kernel_size, + stride=stride, + groups=groups, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.point_conv = ConvNormAct( + mid_channels, + out_channels, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.spatial_conv(x) + x = self.point_conv(x) + return x + + +class LiteMLA(nn.Module): + """Lightweight multi-scale linear attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: Optional[int] = None, + heads_ratio:
float = 1.0, + dim=8, + use_bias=False, + norm_layer=(None, nn.BatchNorm2d), + act_layer=(None, None), + kernel_func=nn.ReLU, + scales=(5,), + eps=1e-5, + ): + super(LiteMLA, self).__init__() + self.eps = eps + heads = heads or int(in_channels // dim * heads_ratio) + total_dim = heads * dim + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + + self.dim = dim + self.qkv = ConvNormAct( + in_channels, + 3 * total_dim, + 1, + bias=use_bias[0], + norm_layer=norm_layer[0], + act_layer=act_layer[0], + ) + self.aggreg = nn.ModuleList([ + nn.Sequential( + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + scale, + padding=get_same_padding(scale), + groups=3 * total_dim, + bias=use_bias[0], + ), + nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]), + ) + for scale in scales + ]) + self.kernel_func = kernel_func(inplace=False) + + self.proj = ConvNormAct( + total_dim * (1 + len(scales)), + out_channels, + 1, + bias=use_bias[1], + norm_layer=norm_layer[1], + act_layer=act_layer[1], + ) + + def _attn(self, q, k, v): + dtype = v.dtype + q, k, v = q.float(), k.float(), v.float() + kv = k.transpose(-1, -2) @ v + out = q @ kv + out = out[..., :-1] / (out[..., -1:] + self.eps) + return out.to(dtype) + + def forward(self, x): + B, _, H, W = x.shape + + # generate multi-scale q, k, v + qkv = self.qkv(x) + multi_scale_qkv = [qkv] + for op in self.aggreg: + multi_scale_qkv.append(op(qkv)) + multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) + multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2) + q, k, v = multi_scale_qkv.chunk(3, dim=-1) + + # lightweight global attention + q = self.kernel_func(q) + k = self.kernel_func(k) + v = F.pad(v, (0, 1), mode="constant", value=1.) 
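+ # NOTE: padding v with a constant 1.0 channel lets a single matmul carry the linear-attention normalizer: + # in _attn, out[..., -1:] equals q @ (k.T @ 1), and the result is divided by it (plus eps) + # before the extra channel is dropped.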
+ + if not torch.jit.is_scripting(): + with torch.autocast(device_type=v.device.type, enabled=False): + out = self._attn(q, k, v) + else: + out = self._attn(q, k, v) + + # final projection + out = out.transpose(-1, -2).reshape(B, -1, H, W) + out = self.proj(out) + return out + + +register_notrace_module(LiteMLA) + + +class EfficientVitBlock(nn.Module): + def __init__( + self, + in_channels, + heads_ratio=1.0, + head_dim=32, + expand_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + ): + super(EfficientVitBlock, self).__init__() + self.context_module = ResidualBlock( + LiteMLA( + in_channels=in_channels, + out_channels=in_channels, + heads_ratio=heads_ratio, + dim=head_dim, + norm_layer=(None, norm_layer), + ), + nn.Identity(), + ) + self.local_module = ResidualBlock( + MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm_layer=(None, None, norm_layer), + act_layer=(act_layer, act_layer, None), + ), + nn.Identity(), + ) + + def forward(self, x): + x = self.context_module(x) + x = self.local_module(x) + return x + + +class ResidualBlock(nn.Module): + def __init__( + self, + main: Optional[nn.Module], + shortcut: Optional[nn.Module] = None, + pre_norm: Optional[nn.Module] = None, + ): + super(ResidualBlock, self).__init__() + self.pre_norm = pre_norm if pre_norm is not None else nn.Identity() + self.main = main + self.shortcut = shortcut + + def forward(self, x): + res = self.main(self.pre_norm(x)) + if self.shortcut is not None: + res = res + self.shortcut(x) + return res + + +def build_local_block( + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: float, + norm_layer: str, + act_layer: str, + fewer_norm: bool = False, + block_type: str = "default", +): + assert block_type in ["default", "large", "fused"] + if expand_ratio == 1: + if block_type == "default": + block = DSConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + else: + block = ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + else: + if block_type == "default": + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, act_layer, None), + ) + else: + block = FusedMBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + return block + + +class Stem(nn.Sequential): + def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer, block_type='default'): + super().__init__() + self.stride = 2 + + self.add_module( + 'in_conv', + ConvNormAct( + in_chs, out_chs, + kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, + ) + ) + stem_block = 0 + for _ in range(depth): + self.add_module(f'res{stem_block}', ResidualBlock( + build_local_block( + in_channels=out_chs, + out_channels=out_chs, + stride=1, + expand_ratio=1, + norm_layer=norm_layer, + 
act_layer=act_layer, + block_type=block_type, + ), + nn.Identity(), + )) + stem_block += 1 + + +class EfficientVitStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth, + norm_layer, + act_layer, + expand_ratio, + head_dim, + vit_stage=False, + ): + super(EfficientVitStage, self).__init__() + blocks = [ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=2, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + fewer_norm=vit_stage, + ), + None, + )] + in_chs = out_chs + + if vit_stage: + # for stage 3, 4 + for _ in range(depth): + blocks.append( + EfficientVitBlock( + in_channels=in_chs, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + ) + ) + else: + # for stage 1, 2 + for i in range(1, depth): + blocks.append(ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=1, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer + ), + nn.Identity(), + )) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + return self.blocks(x) + + +class EfficientVitLargeStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth, + norm_layer, + act_layer, + head_dim, + vit_stage=False, + fewer_norm=False, + ): + super(EfficientVitLargeStage, self).__init__() + blocks = [ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=2, + expand_ratio=24 if vit_stage else 16, + norm_layer=norm_layer, + act_layer=act_layer, + fewer_norm=vit_stage or fewer_norm, + block_type='default' if fewer_norm else 'fused', + ), + None, + )] + in_chs = out_chs + + if vit_stage: + # for stage 4 + for _ in range(depth): + blocks.append( + EfficientVitBlock( + in_channels=in_chs, + head_dim=head_dim, + expand_ratio=6, + norm_layer=norm_layer, + act_layer=act_layer, + ) + ) + else: + # for stage 1, 2, 3 + for i in range(depth): + blocks.append(ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=1, + expand_ratio=4, + norm_layer=norm_layer, + act_layer=act_layer, + fewer_norm=fewer_norm, + block_type='default' if fewer_norm else 'fused', + ), + nn.Identity(), + )) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + return self.blocks(x) + + +class ClassifierHead(nn.Module): + def __init__( + self, + in_channels, + widths, + n_classes=1000, + dropout=0., + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + global_pool='avg', + norm_eps=1e-5, + ): + super(ClassifierHead, self).__init__() + self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.classifier = nn.Sequential( + nn.Linear(widths[0], widths[1], bias=False), + nn.LayerNorm(widths[1], eps=norm_eps), + act_layer(inplace=True) if act_layer is not None else nn.Identity(), + nn.Dropout(dropout, inplace=False), + nn.Linear(widths[1], n_classes, bias=True), + ) + + def forward(self, x, pre_logits: bool = False): + x = self.in_conv(x) + x = self.global_pool(x) + if pre_logits: + return x + x = self.classifier(x) + return x + + +class EfficientVit(nn.Module): + def __init__( + self, + in_chans=3, + widths=(), + depths=(), + head_dim=32, + expand_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + global_pool='avg', + head_widths=(), + drop_rate=0.0, + num_classes=1000, + ): + super(EfficientVit, self).__init__() + 
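+ # stride-2 stem followed by four stride-2 stages; the last two stages (vit_stage=True below) interleave LiteMLA attention blocks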
self.grad_checkpointing = False + self.global_pool = global_pool + self.num_classes = num_classes + + # input stem + self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer) + stride = self.stem.stride + + # stages + self.feature_info = [] + self.stages = nn.Sequential() + in_channels = widths[0] + for i, (w, d) in enumerate(zip(widths[1:], depths[1:])): + self.stages.append(EfficientVitStage( + in_channels, + w, + depth=d, + norm_layer=norm_layer, + act_layer=act_layer, + expand_ratio=expand_ratio, + head_dim=head_dim, + vit_stage=i >= 2, + )) + stride *= 2 + in_channels = w + self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')] + + self.num_features = in_channels + self.head_widths = head_widths + self.head_dropout = drop_rate + if num_classes > 0: + self.head = ClassifierHead( + self.num_features, + self.head_widths, + n_classes=num_classes, + dropout=self.head_dropout, + global_pool=self.global_pool, + ) + else: + if self.global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + self.head = nn.Identity() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.classifier[-1] + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + if num_classes > 0: + self.head = ClassifierHead( + self.num_features, + self.head_widths, + n_classes=num_classes, + dropout=self.head_dropout, + global_pool=self.global_pool, + ) + else: + if self.global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +class EfficientVitLarge(nn.Module): + def __init__( + self, + in_chans=3, + widths=(), + depths=(), + head_dim=32, + norm_layer=nn.BatchNorm2d, + act_layer=GELUTanh, + global_pool='avg', + head_widths=(), + drop_rate=0.0, + num_classes=1000, + norm_eps=1e-7, + ): + super(EfficientVitLarge, self).__init__() + self.grad_checkpointing = False + self.global_pool = global_pool + self.num_classes = num_classes + self.norm_eps = norm_eps + norm_layer = partial(norm_layer, eps=self.norm_eps) + + # input stem + self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large') + stride = self.stem.stride + + # stages + self.feature_info = [] + self.stages = nn.Sequential() + in_channels = widths[0] + for i, (w, d) in enumerate(zip(widths[1:], depths[1:])): + self.stages.append(EfficientVitLargeStage( + in_channels, + w, + depth=d, + norm_layer=norm_layer, + act_layer=act_layer, + head_dim=head_dim, + vit_stage=i >= 3, + fewer_norm=i >= 2, + )) + stride *= 2 + in_channels = w + self.feature_info += [dict(num_chs=in_channels, 
reduction=stride, module=f'stages.{i}')] + + self.num_features = in_channels + self.head_widths = head_widths + self.head_dropout = drop_rate + if num_classes > 0: + self.head = ClassifierHead( + self.num_features, + self.head_widths, + n_classes=num_classes, + dropout=self.head_dropout, + global_pool=self.global_pool, + act_layer=act_layer, + norm_eps=self.norm_eps, + ) + else: + if self.global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + self.head = nn.Identity() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.classifier[-1] + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + if num_classes > 0: + self.head = ClassifierHead( + self.num_features, + self.head_widths, + n_classes=num_classes, + dropout=self.head_dropout, + global_pool=self.global_pool, + norm_eps=self.norm_eps + ) + else: + if self.global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.in_conv.conv', + 'classifier': 'head.classifier.4', + 'crop_pct': 0.95, + 'input_size': (3, 224, 224), + 'pool_size': (7, 7), + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'efficientvit_b0.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b1.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b1.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_b1.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), + 'efficientvit_b2.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b2.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_b2.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), + 'efficientvit_b3.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b3.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_b3.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), + 'efficientvit_l1.r224_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + ), + 'efficientvit_l2.r224_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + ), + 'efficientvit_l2.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 
'efficientvit_l2.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), + 'efficientvit_l2.r384_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'efficientvit_l3.r224_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + ), + 'efficientvit_l3.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_l3.r320_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, + ), + 'efficientvit_l3.r384_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + # 'efficientvit_l0_sam.sam': _cfg( + # # hf_hub_id='timm/', + # input_size=(3, 512, 512), crop_pct=1.0, + # num_classes=0, + # ), + # 'efficientvit_l1_sam.sam': _cfg( + # # hf_hub_id='timm/', + # input_size=(3, 512, 512), crop_pct=1.0, + # num_classes=0, + # ), + # 'efficientvit_l2_sam.sam': _cfg( + # # hf_hub_id='timm/', + # input_size=(3, 512, 512), crop_pct=1.0, + # num_classes=0, + # ), +}) + + +def _create_efficientvit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + EfficientVit, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +def _create_efficientvit_large(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + EfficientVitLarge, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +@register_model +def efficientvit_b0(pretrained=False, **kwargs): + model_args = dict( + widths=(8, 16, 32, 64, 128), depths=(1, 2, 2, 2, 2), head_dim=16, head_widths=(1024, 1280)) + return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_b1(pretrained=False, **kwargs): + model_args = dict( + widths=(16, 32, 64, 128, 256), depths=(1, 2, 3, 3, 4), head_dim=16, head_widths=(1536, 1600)) + return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_b2(pretrained=False, **kwargs): + model_args = dict( + widths=(24, 48, 96, 192, 384), depths=(1, 3, 4, 4, 6), head_dim=32, head_widths=(2304, 2560)) + return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_b3(pretrained=False, **kwargs): + model_args = dict( + widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560)) + return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_l1(pretrained=False, **kwargs): + model_args = dict( + widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, head_widths=(3072, 3200)) + return _create_efficientvit_large('efficientvit_l1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_l2(pretrained=False, **kwargs): + model_args = dict( + widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(3072, 3200)) + return _create_efficientvit_large('efficientvit_l2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_l3(pretrained=False, **kwargs): + model_args = dict( + 
widths=(64, 128, 256, 512, 1024), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(6144, 6400)) + return _create_efficientvit_large('efficientvit_l3', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# FIXME will wait for v2 SAM models which are pending +# @register_model +# def efficientvit_l0_sam(pretrained=False, **kwargs): +# # only backbone for segment-anything-model weights +# model_args = dict( +# widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6) +# return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs)) +# +# +# @register_model +# def efficientvit_l1_sam(pretrained=False, **kwargs): +# # only backbone for segment-anything-model weights +# model_args = dict( +# widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6) +# return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs)) +# +# +# @register_model +# def efficientvit_l2_sam(pretrained=False, **kwargs): +# # only backbone for segment-anything-model weights +# model_args = dict( +# widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6) +# return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientvit_msra.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientvit_msra.py new file mode 100644 index 0000000000000000000000000000000000000000..1b7f52a02f3b9f0b895d6301c2230fef82e0bd31 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/efficientvit_msra.py @@ -0,0 +1,659 @@ +""" EfficientViT (by MSRA) + +Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention` + - https://arxiv.org/abs/2305.07027 + +Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT +""" + +__all__ = ['EfficientVitMsra'] +import itertools +from collections import OrderedDict +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class ConvNorm(torch.nn.Sequential): + def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) + self.bn = nn.BatchNorm2d(out_chs) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self.conv, self.bn + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d( + w.size(1) * self.conv.groups, w.size(0), w.shape[2:], + stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class NormLinear(torch.nn.Sequential): + def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.): + 
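+ # BatchNorm1d -> Dropout -> Linear; fuse() below folds the BN into the Linear for deployment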
super().__init__() + self.bn = nn.BatchNorm1d(in_features) + self.drop = nn.Dropout(drop) + self.linear = nn.Linear(in_features, out_features, bias=bias) + + trunc_normal_(self.linear.weight, std=std) + if self.linear.bias is not None: + nn.init.constant_(self.linear.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, linear = self.bn, self.linear + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = linear.weight * w[None, :] + if linear.bias is None: + b = b @ self.linear.weight.T + else: + b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchMerging(torch.nn.Module): + def __init__(self, dim, out_dim): + super().__init__() + hid_dim = int(dim * 4) + self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0) + self.act = torch.nn.ReLU() + self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) + self.se = SqueezeExcite(hid_dim, .25) + self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0) + + def forward(self, x): + x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x)))))) + return x + + +class ResidualDrop(torch.nn.Module): + def __init__(self, m, drop=0.): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class ConvMlp(torch.nn.Module): + def __init__(self, ed, h): + super().__init__() + self.pw1 = ConvNorm(ed, h) + self.act = torch.nn.ReLU() + self.pw2 = ConvNorm(h, ed, bn_weight_init=0) + + def forward(self, x): + x = self.pw2(self.act(self.pw1(x))) + return x + + +class CascadedGroupAttention(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + + r""" Cascaded Group Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution, correspond to the window size. + kernels (List[int]): The kernel size of the dw conv on query. 
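+ + Example (illustrative only; dim must be divisible by num_heads and the input's spatial size must equal ``resolution``): + >>> import torch + >>> attn = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, resolution=7) + >>> attn(torch.randn(2, 64, 7, 7)).shape + torch.Size([2, 64, 7, 7])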
+ """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + kernels=(5, 5, 5, 5), + ): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.val_dim = int(attn_ratio * key_dim) + self.attn_ratio = attn_ratio + + qkvs = [] + dws = [] + for i in range(num_heads): + qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim)) + dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim)) + self.qkvs = torch.nn.ModuleList(qkvs) + self.dws = torch.nn.ModuleList(dws) + self.proj = torch.nn.Sequential( + torch.nn.ReLU(), + ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0) + ) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.attention_bias_cache = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if torch.jit.is_tracing() or self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): + B, C, H, W = x.shape + feats_in = x.chunk(len(self.qkvs), dim=1) + feats_out = [] + feat = feats_in[0] + attn_bias = self.get_attention_biases(x.device) + for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)): + if head_idx > 0: + feat = feat + feats_in[head_idx] + feat = qkv(feat) + q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1) + q = dws(q) + q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) + q = q * self.scale + attn = q.transpose(-2, -1) @ k + attn = attn + attn_bias[head_idx] + attn = attn.softmax(dim=-1) + feat = v @ attn.transpose(-2, -1) + feat = feat.view(B, self.val_dim, H, W) + feats_out.append(feat) + x = self.proj(torch.cat(feats_out, 1)) + return x + + +class LocalWindowAttention(torch.nn.Module): + r""" Local Window Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. 
+ """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=(5, 5, 5, 5), + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.resolution = resolution + assert window_resolution > 0, 'window_size must be greater than 0' + self.window_resolution = window_resolution + window_resolution = min(window_resolution, resolution) + self.attn = CascadedGroupAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=window_resolution, + kernels=kernels, + ) + + def forward(self, x): + H = W = self.resolution + B, C, H_, W_ = x.shape + # Only check this for classifcation models + _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}') + _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}') + if H <= self.window_resolution and W <= self.window_resolution: + x = self.attn(x) + else: + x = x.permute(0, 2, 3, 1) + pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution + pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution + x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_resolution + nW = pW // self.window_resolution + # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw + x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3) + x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2) + x = self.attn(x) + # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC + x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C) + x = x.transpose(2, 3).reshape(B, pH, pW, C) + x = x[:, :H, :W].contiguous() + x = x.permute(0, 3, 1, 2) + return x + + +class EfficientVitBlock(torch.nn.Module): + """ A basic EfficientVit building block. + + Args: + dim (int): Number of input channels. + key_dim (int): Dimension for query and key in the token mixer. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. 
+ """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + ): + super().__init__() + + self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + + self.mixer = ResidualDrop( + LocalWindowAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=resolution, + window_resolution=window_resolution, + kernels=kernels, + ) + ) + + self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + + def forward(self, x): + return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) + + +class EfficientVitStage(torch.nn.Module): + def __init__( + self, + in_dim, + out_dim, + key_dim, + downsample=('', 1), + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + depth=1, + ): + super().__init__() + if downsample[0] == 'subsample': + self.resolution = (resolution - 1) // downsample[1] + 1 + down_blocks = [] + down_blocks.append(( + 'res1', + torch.nn.Sequential( + ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim)), + ResidualDrop(ConvMlp(in_dim, int(in_dim * 2))), + ) + )) + down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim))) + down_blocks.append(( + 'res2', + torch.nn.Sequential( + ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim)), + ResidualDrop(ConvMlp(out_dim, int(out_dim * 2))), + ) + )) + self.downsample = nn.Sequential(OrderedDict(down_blocks)) + else: + assert in_dim == out_dim + self.downsample = nn.Identity() + self.resolution = resolution + + blocks = [] + for d in range(depth): + blocks.append(EfficientVitBlock(out_dim, key_dim, num_heads, attn_ratio, self.resolution, window_resolution, kernels)) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class PatchEmbedding(torch.nn.Sequential): + def __init__(self, in_chans, dim): + super().__init__() + self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1)) + self.add_module('relu1', torch.nn.ReLU()) + self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1)) + self.add_module('relu2', torch.nn.ReLU()) + self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1)) + self.add_module('relu3', torch.nn.ReLU()) + self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1)) + self.patch_size = 16 + + +class EfficientVitMsra(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dim=(64, 128, 192), + key_dim=(16, 16, 16), + depth=(1, 2, 3), + num_heads=(4, 4, 4), + window_size=(7, 7, 7), + kernels=(5, 5, 5, 5), + down_ops=(('', 1), ('subsample', 2), ('subsample', 2)), + global_pool='avg', + drop_rate=0., + ): + super(EfficientVitMsra, self).__init__() + self.grad_checkpointing = False + self.num_classes = num_classes + self.drop_rate = drop_rate + + # Patch embedding + self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) + stride = self.patch_embed.patch_size + resolution = img_size // self.patch_embed.patch_size + attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] + + # Build EfficientVit blocks + self.feature_info = [] + stages = [] + pre_ed = embed_dim[0] + for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)): + stage = EfficientVitStage( + in_dim=pre_ed, + 
out_dim=ed, + key_dim=kd, + downsample=do, + num_heads=nh, + attn_ratio=ar, + resolution=resolution, + window_resolution=wd, + kernels=kernels, + depth=dpth, + ) + pre_ed = ed + if do[0] == 'subsample' and i != 0: + stride *= do[1] + resolution = stage.resolution + stages.append(stage) + self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + + if global_pool == 'avg': + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + assert num_classes == 0 + self.global_pool = nn.Identity() + self.num_features = embed_dim[-1] + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.linear + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + if global_pool == 'avg': + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + assert num_classes == 0 + self.global_pool = nn.Identity() + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +# def checkpoint_filter_fn(state_dict, model): +# if 'model' in state_dict.keys(): +# state_dict = state_dict['model'] +# tmp_dict = {} +# out_dict = {} +# target_keys = model.state_dict().keys() +# target_keys = [k for k in target_keys if k.startswith('stages.')] +# +# for k, v in state_dict.items(): +# if 'attention_bias_idxs' in k: +# continue +# k = k.split('.') +# if k[-2] == 'c': +# k[-2] = 'conv' +# if k[-2] == 'l': +# k[-2] = 'linear' +# k = '.'.join(k) +# tmp_dict[k] = v +# +# for k, v in tmp_dict.items(): +# if k.startswith('patch_embed'): +# k = k.split('.') +# k[1] = 'conv' + str(int(k[1]) // 2 + 1) +# k = '.'.join(k) +# elif k.startswith('blocks'): +# kw = '.'.join(k.split('.')[2:]) +# find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a] +# idx = find_kw.index(k) +# k = [a for a in target_keys if kw in a][idx] +# out_dict[k] = v +# +# return out_dict + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv1.conv', + 'classifier': 'head.linear', + 'fixed_input_size': True, + 'pool_size': (4, 4), + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'efficientvit_m0.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth' + ), + 
'efficientvit_m1.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth' + ), + 'efficientvit_m2.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth' + ), + 'efficientvit_m3.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth' + ), + 'efficientvit_m4.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth' + ), + 'efficientvit_m5.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth' + ), +}) + + +def _create_efficientvit_msra(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2)) + model = build_model_with_cfg( + EfficientVitMsra, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +@register_model +def efficientvit_m0(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[64, 128, 192], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5] + ) + return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m1(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 144, 192], + depth=[1, 2, 3], + num_heads=[2, 3, 3], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m2(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 192, 224], + depth=[1, 2, 3], + num_heads=[4, 3, 2], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m3(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 240, 320], + depth=[1, 2, 3], + num_heads=[4, 3, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5] + ) + return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m4(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 256, 384], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m5(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[192, 288, 384], + depth=[1, 3, 4], + num_heads=[3, 3, 4], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/fastvit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/fastvit.py new file mode 100644 index 0000000000000000000000000000000000000000..67961880b1a0c0e36ffc6a891535ce5af7aabd21 --- /dev/null +++ 
b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/fastvit.py @@ -0,0 +1,1421 @@ +# FastViT for PyTorch +# +# Original implementation and weights from https://github.com/apple/ml-fastvit +# +# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main +# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved. +# +import os +from functools import partial +from typing import Tuple, Optional, Union + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ + ClassifierHead +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MobileOneBlock(nn.Module): + """MobileOne building block. + + This block has a multi-branched architecture at train-time + and plain-CNN style architecture at inference time. + For more details, please refer to our paper: + `An Improved One millisecond Mobile Backbone` - + https://arxiv.org/pdf/2206.04040.pdf + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + act_layer: nn.Module = nn.GELU, + ) -> None: + """Construct a MobileOneBlock module. + + Args: + in_chs: Number of channels in the input. + out_chs: Number of channels produced by the block. + kernel_size: Size of the convolution kernel. + stride: Stride size. + dilation: Kernel dilation factor. + group_size: Convolution group size. + inference_mode: If True, instantiates model in inference mode. + use_se: Whether to use SE-ReLU activations. + use_act: Whether to use activation. Default: ``True`` + use_scale_branch: Whether to use scale branch. Default: ``True`` + num_conv_branches: Number of linear conv branches. 
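+ + Example (illustrative; a depthwise 3x3 block whose branches collapse into a single conv after reparameterization, matching outputs up to float tolerance): + >>> import torch + >>> blk = MobileOneBlock(in_chs=64, out_chs=64, kernel_size=3, group_size=1).eval() + >>> x = torch.randn(2, 64, 32, 32) + >>> y = blk(x) + >>> blk.reparameterize() + >>> torch.allclose(y, blk(x), atol=1e-5) + True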
+ """ + super(MobileOneBlock, self).__init__() + self.inference_mode = inference_mode + self.groups = num_groups(group_size, in_chs) + self.stride = stride + self.dilation = dilation + self.kernel_size = kernel_size + self.in_chs = in_chs + self.out_chs = out_chs + self.num_conv_branches = num_conv_branches + + # Check if SE-ReLU is requested + self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity() + + if inference_mode: + self.reparam_conv = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=self.groups, + bias=True, + ) + else: + # Re-parameterizable skip connection + self.reparam_conv = None + + self.identity = ( + nn.BatchNorm2d(num_features=in_chs) + if out_chs == in_chs and stride == 1 + else None + ) + + # Re-parameterizable conv branches + if num_conv_branches > 0: + self.conv_kxk = nn.ModuleList([ + ConvNormAct( + self.in_chs, + self.out_chs, + kernel_size=kernel_size, + stride=self.stride, + groups=self.groups, + apply_act=False, + ) for _ in range(self.num_conv_branches) + ]) + else: + self.conv_kxk = None + + # Re-parameterizable scale branch + self.conv_scale = None + if kernel_size > 1 and use_scale_branch: + self.conv_scale = ConvNormAct( + self.in_chs, + self.out_chs, + kernel_size=1, + stride=self.stride, + groups=self.groups, + apply_act=False + ) + + self.act = act_layer() if use_act else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply forward pass.""" + # Inference mode forward pass. + if self.reparam_conv is not None: + return self.act(self.se(self.reparam_conv(x))) + + # Multi-branched train-time forward pass. + # Identity branch output + identity_out = 0 + if self.identity is not None: + identity_out = self.identity(x) + + # Scale branch output + scale_out = 0 + if self.conv_scale is not None: + scale_out = self.conv_scale(x) + + # Other kxk conv branches + out = scale_out + identity_out + if self.conv_kxk is not None: + for rc in self.conv_kxk: + out += rc(x) + + return self.act(self.se(out)) + + def reparameterize(self): + """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + if self.reparam_conv is not None: + return + + kernel, bias = self._get_kernel_bias() + self.reparam_conv = create_conv2d( + in_channels=self.in_chs, + out_channels=self.out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + bias=True, + ) + self.reparam_conv.weight.data = kernel + self.reparam_conv.bias.data = bias + + # Delete un-used branches + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + + self.__delattr__("conv_kxk") + self.__delattr__("conv_scale") + if hasattr(self, "identity"): + self.__delattr__("identity") + + self.inference_mode = True + + def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.conv_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale) + # Pad scale branch kernel to match conv branch kernel size. 
pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.identity is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + if self.conv_kxk is not None: + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor( + self, branch: Union[nn.Sequential, nn.BatchNorm2d] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with preceding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + Args: + branch: Sequence of ops to be fused. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + if isinstance(branch, ConvNormAct): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_chs // self.groups + kernel_value = torch.zeros( + (self.in_chs, input_dim, self.kernel_size, self.kernel_size), + dtype=branch.weight.dtype, + device=branch.weight.device, + ) + for i in range(self.in_chs): + kernel_value[ + i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2 + ] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ReparamLargeKernelConv(nn.Module): + """Building Block of RepLKNet + + This class defines overparameterized large kernel conv block + introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_ + + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + group_size: int, + small_kernel: Optional[int] = None, + inference_mode: bool = False, + act_layer: Optional[nn.Module] = None, + ) -> None: + """Construct a ReparamLargeKernelConv module. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + kernel_size: Kernel size of the large kernel conv branch. + stride: Stride size. Default: 1 + group_size: Group size. Default: 1 + small_kernel: Kernel size of small kernel conv branch. + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + act_layer: Activation module. 
Default: ``nn.GELU`` + """ + super(ReparamLargeKernelConv, self).__init__() + self.stride = stride + self.groups = num_groups(group_size, in_chs) + self.in_chs = in_chs + self.out_chs = out_chs + + self.kernel_size = kernel_size + self.small_kernel = small_kernel + if inference_mode: + self.reparam_conv = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=1, + groups=self.groups, + bias=True, + ) + else: + self.reparam_conv = None + self.large_conv = ConvNormAct( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=self.stride, + groups=self.groups, + apply_act=False, + ) + if small_kernel is not None: + assert ( + small_kernel <= kernel_size + ), "The kernel size for re-param cannot be larger than the large kernel!" + self.small_conv = ConvNormAct( + in_chs, + out_chs, + kernel_size=small_kernel, + stride=self.stride, + groups=self.groups, + apply_act=False, + ) + # FIXME output of this act was not used in original impl, likely due to bug + self.act = act_layer() if act_layer is not None else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.reparam_conv is not None: + out = self.reparam_conv(x) + else: + out = self.large_conv(x) + if self.small_conv is not None: + out = out + self.small_conv(x) + out = self.act(out) + return out + + def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn) + if hasattr(self, "small_conv"): + small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + eq_k += nn.functional.pad( + small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 + ) + return eq_k, eq_b + + def reparameterize(self) -> None: + """ + Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + eq_k, eq_b = self.get_kernel_bias() + self.reparam_conv = create_conv2d( + self.in_chs, + self.out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + groups=self.groups, + bias=True, + ) + + self.reparam_conv.weight.data = eq_k + self.reparam_conv.bias.data = eq_b + self.__delattr__("large_conv") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + @staticmethod + def _fuse_bn( + conv: nn.Conv2d, bn: nn.BatchNorm2d + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with conv layer. + + Args: + conv: Conv2d layer supplying the kernel weights. + bn: Batchnorm 2d layer. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +def convolutional_stem( + in_chs: int, + out_chs: int, + act_layer: nn.Module = nn.GELU, + inference_mode: bool = False +) -> nn.Sequential: + """Build convolutional stem with MobileOne blocks. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + inference_mode: Flag to instantiate model in inference mode. 
+
+
+def convolutional_stem(
+        in_chs: int,
+        out_chs: int,
+        act_layer: nn.Module = nn.GELU,
+        inference_mode: bool = False
+) -> nn.Sequential:
+    """Build convolutional stem with MobileOne blocks.
+
+    Args:
+        in_chs: Number of input channels.
+        out_chs: Number of output channels.
+        act_layer: Activation layer. Default: ``nn.GELU``
+        inference_mode: Flag to instantiate model in inference mode. Default: ``False``
+
+    Returns:
+        nn.Sequential object with stem elements.
+    """
+    return nn.Sequential(
+        MobileOneBlock(
+            in_chs=in_chs,
+            out_chs=out_chs,
+            kernel_size=3,
+            stride=2,
+            act_layer=act_layer,
+            inference_mode=inference_mode,
+        ),
+        MobileOneBlock(
+            in_chs=out_chs,
+            out_chs=out_chs,
+            kernel_size=3,
+            stride=2,
+            group_size=1,
+            act_layer=act_layer,
+            inference_mode=inference_mode,
+        ),
+        MobileOneBlock(
+            in_chs=out_chs,
+            out_chs=out_chs,
+            kernel_size=1,
+            stride=1,
+            act_layer=act_layer,
+            inference_mode=inference_mode,
+        ),
+    )
+
+
+class Attention(nn.Module):
+    """Multi-headed Self Attention module.
+
+    Source modified from:
+    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            head_dim: int = 32,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.0,
+            proj_drop: float = 0.0,
+    ) -> None:
+        """Build MHSA module that can handle 3D or 4D input tensors.
+
+        Args:
+            dim: Number of embedding dimensions.
+            head_dim: Number of hidden dimensions per head. Default: ``32``
+            qkv_bias: Use bias or not. Default: ``False``
+            attn_drop: Dropout rate for attention tensor.
+            proj_drop: Dropout rate for projection tensor.
+        """
+        super().__init__()
+        assert dim % head_dim == 0, "dim should be divisible by head_dim"
+        self.head_dim = head_dim
+        self.num_heads = dim // head_dim
+        self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, H, W = x.shape
+        N = H * W
+        x = x.flatten(2).transpose(-2, -1)  # (B, N, C)
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+        if self.fused_attn:
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        x = x.transpose(-2, -1).reshape(B, C, H, W)
+
+        return x
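# --- Editor's aside (illustrative sketch, not diff content) -----------------
# Shape check for the Attention module above: NCHW maps are flattened to
# (B, H*W, C) tokens, mixed with MHSA, then restored to NCHW so the block can
# slot into a conv stage. Assumes this module is importable as
# timm.models.fastvit.
import torch
from timm.models.fastvit import Attention

attn = Attention(dim=256, head_dim=32)  # 256 / 32 = 8 heads
x = torch.randn(2, 256, 14, 14)         # (B, C, H, W)
assert attn(x).shape == x.shape         # token mixing preserves the map shape
# -----------------------------------------------------------------------------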
+
+
+class PatchEmbed(nn.Module):
+    """Convolutional patch embedding layer."""
+
+    def __init__(
+            self,
+            patch_size: int,
+            stride: int,
+            in_chs: int,
+            embed_dim: int,
+            act_layer: nn.Module = nn.GELU,
+            lkc_use_act: bool = False,
+            inference_mode: bool = False,
+    ) -> None:
+        """Build patch embedding layer.
+
+        Args:
+            patch_size: Patch size for embedding computation.
+            stride: Stride for convolutional embedding layer.
+            in_chs: Number of channels of input tensor.
+            embed_dim: Number of embedding dimensions.
+            act_layer: Activation layer. Default: ``nn.GELU``
+            lkc_use_act: If True, apply the activation inside the large kernel conv branch.
+            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
+        """
+        super().__init__()
+        self.proj = nn.Sequential(
+            ReparamLargeKernelConv(
+                in_chs=in_chs,
+                out_chs=embed_dim,
+                kernel_size=patch_size,
+                stride=stride,
+                group_size=1,
+                small_kernel=3,
+                inference_mode=inference_mode,
+                act_layer=act_layer if lkc_use_act else None,  # NOTE original weights didn't use this act
+            ),
+            MobileOneBlock(
+                in_chs=embed_dim,
+                out_chs=embed_dim,
+                kernel_size=1,
+                stride=1,
+                act_layer=act_layer,
+                inference_mode=inference_mode,
+            )
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        return x
+
+
+class LayerScale2d(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class RepMixer(nn.Module):
+    """Reparameterizable token mixer.
+
+    For more details, please refer to our paper:
+    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/abs/2303.14189>`_
+    """
+
+    def __init__(
+            self,
+            dim,
+            kernel_size=3,
+            layer_scale_init_value=1e-5,
+            inference_mode: bool = False,
+    ):
+        """Build RepMixer Module.
+
+        Args:
+            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
+            kernel_size: Kernel size for spatial mixing. Default: 3
+            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
+            inference_mode: If True, instantiates model in inference mode. Default: ``False``
+        """
+        super().__init__()
+        self.dim = dim
+        self.kernel_size = kernel_size
+        self.inference_mode = inference_mode
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                self.dim,
+                self.dim,
+                kernel_size=self.kernel_size,
+                stride=1,
+                padding=self.kernel_size // 2,
+                groups=self.dim,
+                bias=True,
+            )
+        else:
+            self.reparam_conv = None
+            self.norm = MobileOneBlock(
+                dim,
+                dim,
+                kernel_size,
+                group_size=1,
+                use_act=False,
+                use_scale_branch=False,
+                num_conv_branches=0,
+            )
+            self.mixer = MobileOneBlock(
+                dim,
+                dim,
+                kernel_size,
+                group_size=1,
+                use_act=False,
+            )
+            if layer_scale_init_value is not None:
+                self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
+            else:
+                self.layer_scale = nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.reparam_conv is not None:
+            x = self.reparam_conv(x)
+        else:
+            x = x + self.layer_scale(self.mixer(x) - self.norm(x))
+        return x
+
+    def reparameterize(self) -> None:
+        """Reparameterize mixer and norm into a single
+        convolutional layer for efficient inference.
+        """
+        if self.inference_mode:
+            return
+
+        self.mixer.reparameterize()
+        self.norm.reparameterize()
+
+        if isinstance(self.layer_scale, LayerScale2d):
+            w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * (
+                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
+            )
+            b = torch.squeeze(self.layer_scale.gamma) * (
+                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+            )
+        else:
+            w = (
+                self.mixer.id_tensor
+                + self.mixer.reparam_conv.weight
+                - self.norm.reparam_conv.weight
+            )
+            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+
+        self.reparam_conv = create_conv2d(
+            self.dim,
+            self.dim,
+            kernel_size=self.kernel_size,
+            stride=1,
+            groups=self.dim,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = w
+        self.reparam_conv.bias.data = b
+
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
+            para.detach_()
+        self.__delattr__("mixer")
+        self.__delattr__("norm")
+        self.__delattr__("layer_scale")
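# --- Editor's aside (illustrative sketch, not diff content) -----------------
# RepMixer computes x + layer_scale * (mixer(x) - norm(x)); after
# reparameterize() the residual and both depthwise branches collapse into the
# single conv in reparam_conv. A quick equivalence check, assuming the module
# is importable as timm.models.fastvit:
import torch
from timm.models.fastvit import RepMixer

m = RepMixer(dim=64, kernel_size=3).eval()
x = torch.randn(1, 64, 28, 28)
with torch.no_grad():
    y0 = m(x)
    m.reparameterize()   # now a single depthwise nn.Conv2d
    y1 = m(x)
assert torch.allclose(y0, y1, atol=1e-5)
# -----------------------------------------------------------------------------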
+ """ + if self.inference_mode: + return + + self.mixer.reparameterize() + self.norm.reparameterize() + + if isinstance(self.layer_scale, LayerScale2d): + w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * ( + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + ) + b = torch.squeeze(self.layer_scale.gamma) * ( + self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + ) + else: + w = ( + self.mixer.id_tensor + + self.mixer.reparam_conv.weight + - self.norm.reparam_conv.weight + ) + b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + + self.reparam_conv = create_conv2d( + self.dim, + self.dim, + kernel_size=self.kernel_size, + stride=1, + groups=self.dim, + bias=True, + ) + self.reparam_conv.weight.data = w + self.reparam_conv.bias.data = b + + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + self.__delattr__("mixer") + self.__delattr__("norm") + self.__delattr__("layer_scale") + + +class ConvMlp(nn.Module): + """Convolutional FFN Module.""" + + def __init__( + self, + in_chs: int, + hidden_channels: Optional[int] = None, + out_chs: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Build convolutional FFN module. + + Args: + in_chs: Number of input channels. + hidden_channels: Number of channels after expansion. Default: None + out_chs: Number of output channels. Default: None + act_layer: Activation layer. Default: ``GELU`` + drop: Dropout rate. Default: ``0.0``. + """ + super().__init__() + out_chs = out_chs or in_chs + hidden_channels = hidden_channels or in_chs + self.conv = ConvNormAct( + in_chs, + out_chs, + kernel_size=7, + groups=in_chs, + apply_act=False, + ) + self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class RepConditionalPosEnc(nn.Module): + """Implementation of conditional positional encoding. + + For more details refer to paper: + `Conditional Positional Encodings for Vision Transformers `_ + + In our implementation, we can reparameterize this module to eliminate a skip connection. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, + ) -> None: + """Build reparameterizable conditional positional encoding + + Args: + dim: Number of input channels. + dim_out: Number of embedding dimensions. Default: 768 + spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super(RepConditionalPosEnc, self).__init__() + if isinstance(spatial_shape, int): + spatial_shape = tuple([spatial_shape] * 2) + assert isinstance(spatial_shape, Tuple), ( + f'"spatial_shape" must by a sequence or int, ' + f"get {type(spatial_shape)} instead." + ) + assert len(spatial_shape) == 2, ( + f'Length of "spatial_shape" should be 2, ' + f"got {len(spatial_shape)} instead." 
+
+
+class RepConditionalPosEnc(nn.Module):
+    """Implementation of conditional positional encoding.
+
+    For more details refer to paper:
+    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/abs/2102.10882>`_
+
+    In our implementation, we can reparameterize this module to eliminate a skip connection.
+    """
+
+    def __init__(
+            self,
+            dim: int,
+            dim_out: Optional[int] = None,
+            spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
+            inference_mode=False,
+    ) -> None:
+        """Build reparameterizable conditional positional encoding.
+
+        Args:
+            dim: Number of input channels.
+            dim_out: Number of embedding dimensions. Defaults to ``dim`` when ``None``.
+            spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
+            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+        """
+        super(RepConditionalPosEnc, self).__init__()
+        if isinstance(spatial_shape, int):
+            spatial_shape = tuple([spatial_shape] * 2)
+        assert isinstance(spatial_shape, Tuple), (
+            f'"spatial_shape" must be a sequence or int, '
+            f"got {type(spatial_shape)} instead."
+        )
+        assert len(spatial_shape) == 2, (
+            f'Length of "spatial_shape" should be 2, '
+            f"got {len(spatial_shape)} instead."
+        )
+
+        self.spatial_shape = spatial_shape
+        self.dim = dim
+        self.dim_out = dim_out or dim
+        self.groups = dim
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                self.dim,
+                self.dim_out,
+                kernel_size=self.spatial_shape,
+                stride=1,
+                padding=spatial_shape[0] // 2,
+                groups=self.groups,
+                bias=True,
+            )
+        else:
+            self.reparam_conv = None
+            self.pos_enc = nn.Conv2d(
+                self.dim,
+                self.dim_out,
+                spatial_shape,
+                1,
+                int(spatial_shape[0] // 2),
+                groups=self.groups,
+                bias=True,
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.reparam_conv is not None:
+            x = self.reparam_conv(x)
+        else:
+            x = self.pos_enc(x) + x
+        return x
+
+    def reparameterize(self) -> None:
+        # Build equivalent Id tensor
+        input_dim = self.dim // self.groups
+        kernel_value = torch.zeros(
+            (
+                self.dim,
+                input_dim,
+                self.spatial_shape[0],
+                self.spatial_shape[1],
+            ),
+            dtype=self.pos_enc.weight.dtype,
+            device=self.pos_enc.weight.device,
+        )
+        for i in range(self.dim):
+            kernel_value[
+                i,
+                i % input_dim,
+                self.spatial_shape[0] // 2,
+                self.spatial_shape[1] // 2,
+            ] = 1
+        id_tensor = kernel_value
+
+        # Reparameterize Id tensor and conv
+        w_final = id_tensor + self.pos_enc.weight
+        b_final = self.pos_enc.bias
+
+        # Introduce reparam conv
+        self.reparam_conv = nn.Conv2d(
+            self.dim,
+            self.dim_out,
+            kernel_size=self.spatial_shape,
+            stride=1,
+            padding=int(self.spatial_shape[0] // 2),
+            groups=self.groups,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = w_final
+        self.reparam_conv.bias.data = b_final
+
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
+            para.detach_()
+        self.__delattr__("pos_enc")
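# --- Editor's aside (illustrative sketch, not diff content) -----------------
# The skip connection in x + pos_enc(x) folds into the conv by adding 1 at the
# kernel center of each depthwise filter, which is exactly what
# reparameterize() above constructs. Assumes import from timm.models.fastvit.
import torch
from timm.models.fastvit import RepConditionalPosEnc

pos = RepConditionalPosEnc(dim=64, spatial_shape=(7, 7)).eval()
x = torch.randn(1, 64, 14, 14)
with torch.no_grad():
    y0 = pos(x)
    pos.reparameterize()
    y1 = pos(x)
assert torch.allclose(y0, y1, atol=1e-5)
# -----------------------------------------------------------------------------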
+
+
+class RepMixerBlock(nn.Module):
+    """Implementation of Metaformer block with RepMixer as token mixer.
+
+    For more details on Metaformer structure, please refer to:
+    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/abs/2111.11418>`_
+    """
+
+    def __init__(
+            self,
+            dim: int,
+            kernel_size: int = 3,
+            mlp_ratio: float = 4.0,
+            act_layer: nn.Module = nn.GELU,
+            proj_drop: float = 0.0,
+            drop_path: float = 0.0,
+            layer_scale_init_value: float = 1e-5,
+            inference_mode: bool = False,
+    ):
+        """Build RepMixer Block.
+
+        Args:
+            dim: Number of embedding dimensions.
+            kernel_size: Kernel size for repmixer. Default: 3
+            mlp_ratio: MLP expansion ratio. Default: 4.0
+            act_layer: Activation layer. Default: ``nn.GELU``
+            proj_drop: Dropout rate. Default: 0.0
+            drop_path: Drop path rate. Default: 0.0
+            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+        """
+
+        super().__init__()
+
+        self.token_mixer = RepMixer(
+            dim,
+            kernel_size=kernel_size,
+            layer_scale_init_value=layer_scale_init_value,
+            inference_mode=inference_mode,
+        )
+
+        self.mlp = ConvMlp(
+            in_chs=dim,
+            hidden_channels=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=proj_drop,
+        )
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        x = self.token_mixer(x)
+        x = x + self.drop_path(self.layer_scale(self.mlp(x)))
+        return x
+
+
+class AttentionBlock(nn.Module):
+    """Implementation of metaformer block with MHSA as token mixer.
+
+    For more details on Metaformer structure, please refer to:
+    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/abs/2111.11418>`_
+    """
+
+    def __init__(
+            self,
+            dim: int,
+            mlp_ratio: float = 4.0,
+            act_layer: nn.Module = nn.GELU,
+            norm_layer: nn.Module = nn.BatchNorm2d,
+            proj_drop: float = 0.0,
+            drop_path: float = 0.0,
+            layer_scale_init_value: float = 1e-5,
+    ):
+        """Build Attention Block.
+
+        Args:
+            dim: Number of embedding dimensions.
+            mlp_ratio: MLP expansion ratio. Default: 4.0
+            act_layer: Activation layer. Default: ``nn.GELU``
+            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
+            proj_drop: Dropout rate. Default: 0.0
+            drop_path: Drop path rate. Default: 0.0
+            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+        """
+
+        super().__init__()
+
+        self.norm = norm_layer(dim)
+        self.token_mixer = Attention(dim=dim)
+        if layer_scale_init_value is not None:
+            self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale_1 = nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.mlp = ConvMlp(
+            in_chs=dim,
+            hidden_channels=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=proj_drop,
+        )
+        if layer_scale_init_value is not None:
+            self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale_2 = nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x))))
+        x = x + self.drop_path2(self.layer_scale_2(self.mlp(x)))
+        return x
+ """ + super().__init__() + self.grad_checkpointing = False + + if downsample: + self.downsample = PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + in_chs=dim, + embed_dim=dim_out, + act_layer=act_layer, + lkc_use_act=lkc_use_act, + inference_mode=inference_mode, + ) + else: + assert dim == dim_out + self.downsample = nn.Identity() + + if pos_emb_layer is not None: + self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode) + else: + self.pos_emb = nn.Identity() + + blocks = [] + for block_idx in range(depth): + if token_mixer_type == "repmixer": + blocks.append(RepMixerBlock( + dim_out, + kernel_size=kernel_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + proj_drop=proj_drop_rate, + drop_path=drop_path_rate[block_idx], + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + )) + elif token_mixer_type == "attention": + blocks.append(AttentionBlock( + dim_out, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + proj_drop=proj_drop_rate, + drop_path=drop_path_rate[block_idx], + layer_scale_init_value=layer_scale_init_value, + )) + else: + raise ValueError( + "Token mixer type: {} not supported".format(token_mixer_type) + ) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.pos_emb(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class FastVit(nn.Module): + fork_feat: torch.jit.Final[bool] + + """ + This class implements `FastViT architecture `_ + """ + + def __init__( + self, + in_chans: int = 3, + layers: Tuple[int, ...] = (2, 2, 6, 2), + token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + mlp_ratios: Tuple[float, ...] = (4,) * 4, + downsamples: Tuple[bool, ...] = (False, True, True, True), + repmixer_kernel_size: int = 3, + num_classes: int = 1000, + pos_embs: Tuple[Optional[nn.Module], ...] 
+
+
+class FastVit(nn.Module):
+    fork_feat: torch.jit.Final[bool]
+
+    """
+    This class implements `FastViT architecture <https://arxiv.org/abs/2303.14189>`_
+    """
+
+    def __init__(
+            self,
+            in_chans: int = 3,
+            layers: Tuple[int, ...] = (2, 2, 6, 2),
+            token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
+            embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
+            mlp_ratios: Tuple[float, ...] = (4,) * 4,
+            downsamples: Tuple[bool, ...] = (False, True, True, True),
+            repmixer_kernel_size: int = 3,
+            num_classes: int = 1000,
+            pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
+            down_patch_size: int = 7,
+            down_stride: int = 2,
+            drop_rate: float = 0.0,
+            proj_drop_rate: float = 0.0,
+            drop_path_rate: float = 0.0,
+            layer_scale_init_value: float = 1e-5,
+            fork_feat: bool = False,
+            cls_ratio: float = 2.0,
+            global_pool: str = 'avg',
+            norm_layer: nn.Module = nn.BatchNorm2d,
+            act_layer: nn.Module = nn.GELU,
+            lkc_use_act: bool = False,
+            inference_mode: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_classes = 0 if fork_feat else num_classes
+        self.fork_feat = fork_feat
+        self.global_pool = global_pool
+        self.feature_info = []
+
+        # Convolutional stem
+        self.stem = convolutional_stem(
+            in_chans,
+            embed_dims[0],
+            act_layer,
+            inference_mode,
+        )
+
+        # Build the main stages of the network architecture
+        prev_dim = embed_dims[0]
+        scale = 1
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
+        stages = []
+        for i in range(len(layers)):
+            downsample = downsamples[i] or prev_dim != embed_dims[i]
+            stage = FastVitStage(
+                dim=prev_dim,
+                dim_out=embed_dims[i],
+                depth=layers[i],
+                downsample=downsample,
+                down_patch_size=down_patch_size,
+                down_stride=down_stride,
+                pos_emb_layer=pos_embs[i],
+                token_mixer_type=token_mixers[i],
+                kernel_size=repmixer_kernel_size,
+                mlp_ratio=mlp_ratios[i],
+                act_layer=act_layer,
+                norm_layer=norm_layer,
+                proj_drop_rate=proj_drop_rate,
+                drop_path_rate=dpr[i],
+                layer_scale_init_value=layer_scale_init_value,
+                lkc_use_act=lkc_use_act,
+                inference_mode=inference_mode,
+            )
+            stages.append(stage)
+            prev_dim = embed_dims[i]
+            if downsample:
+                scale *= 2
+            self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
+        self.stages = nn.Sequential(*stages)
+        self.num_features = prev_dim
+
+        # For segmentation and detection, extract intermediate output
+        if self.fork_feat:
+            # Add a norm layer for each output. self.stages is slightly different from
+            # self.network in the original code; here the PatchEmbed layer is part of
+            # self.stages, whereas it was part of self.network in the original code.
+            # So we do not need to skip out indices.
+            self.out_indices = [0, 1, 2, 3]
+            for i_emb, i_layer in enumerate(self.out_indices):
+                if i_emb == 0 and os.environ.get("FORK_LAST3", None):
+                    """For RetinaNet, `start_level=1`. The first norm layer will not be used.
+                    cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
+                    """
+                    layer = nn.Identity()
+                else:
+                    layer = norm_layer(embed_dims[i_emb])
+                layer_name = f"norm{i_layer}"
+                self.add_module(layer_name, layer)
+        else:
+            # Classifier head
+            self.num_features = final_features = int(embed_dims[-1] * cls_ratio)
+            self.final_conv = MobileOneBlock(
+                in_chs=embed_dims[-1],
+                out_chs=final_features,
+                kernel_size=3,
+                stride=1,
+                group_size=1,
+                inference_mode=inference_mode,
+                use_se=True,
+                act_layer=act_layer,
+                num_conv_branches=1,
+            )
+            self.head = ClassifierHead(
+                final_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
+            )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m: nn.Module) -> None:
+        """Init.
for classification""" + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return set() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+).pos_emb', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + # input embedding + x = self.stem(x) + outs = [] + for idx, block in enumerate(self.stages): + x = block(x) + if self.fork_feat: + if idx in self.out_indices: + norm_layer = getattr(self, f"norm{idx}") + x_out = norm_layer(x) + outs.append(x_out) + if self.fork_feat: + # output the features of four stages for dense prediction + return outs + x = self.final_conv(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + if self.fork_feat: + return x + x = self.forward_head(x) + return x + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 256, 256), + "pool_size": (8, 8), + "crop_pct": 0.9, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + 'first_conv': ('stem.0.conv_kxk.0.conv', 'stem.0.conv_scale.conv'), + "classifier": "head.fc", + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + "fastvit_t8.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_t12.apple_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_s12.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_sa12.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_sa24.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_sa36.apple_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_ma36.apple_in1k": _cfg( + hf_hub_id='timm/', + crop_pct=0.95 + ), + + "fastvit_t8.apple_dist_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_t12.apple_dist_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_s12.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa12.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa24.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa36.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + + "fastvit_ma36.apple_dist_in1k": _cfg( + hf_hub_id='timm/', + crop_pct=0.95 + ), +}) + + +def _create_fastvit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + FastVit, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +@register_model +def fastvit_t8(pretrained=False, **kwargs): + """Instantiate FastViT-T8 model variant.""" + model_args = dict( + layers=(2, 2, 4, 2), + embed_dims=(48, 96, 192, 384), + mlp_ratios=(3, 3, 3, 3), + token_mixers=("repmixer", "repmixer", "repmixer", "repmixer") + ) + return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, 
**kwargs)) + + +@register_model +def fastvit_t12(pretrained=False, **kwargs): + """Instantiate FastViT-T12 model variant.""" + model_args = dict( + layers=(2, 2, 6, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), + ) + return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_s12(pretrained=False, **kwargs): + """Instantiate FastViT-S12 model variant.""" + model_args = dict( + layers=(2, 2, 6, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), + ) + return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_sa12(pretrained=False, **kwargs): + """Instantiate FastViT-SA12 model variant.""" + model_args = dict( + layers=(2, 2, 6, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_sa24(pretrained=False, **kwargs): + """Instantiate FastViT-SA24 model variant.""" + model_args = dict( + layers=(4, 4, 12, 4), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_sa36(pretrained=False, **kwargs): + """Instantiate FastViT-SA36 model variant.""" + model_args = dict( + layers=(6, 6, 18, 6), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_ma36(pretrained=False, **kwargs): + """Instantiate FastViT-MA36 model variant.""" + model_args = dict( + layers=(6, 6, 18, 6), + embed_dims=(76, 152, 304, 608), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention") + ) + return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/fx_features.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/fx_features.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff3a18b05844d0d318f6853500988cb1ff29624 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/fx_features.py @@ -0,0 +1,4 @@ +from ._features_fx import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hardcorenas.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hardcorenas.py new file mode 100644 index 
0000000000000000000000000000000000000000..459c1a3db845d2dc8b16c27521397687340ffe98 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hardcorenas.py @@ -0,0 +1,156 @@ +from functools import partial + +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._builder import build_model_with_cfg +from ._builder import pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from ._registry import register_model, generate_default_cfgs +from .mobilenetv3 import MobileNetV3, MobileNetV3Features + +__all__ = [] # model_registry will add each entrypoint fn to this + + +def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): + """Creates a hardcorenas model + + Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS + Paper: https://arxiv.org/abs/2102.11646 + + """ + num_features = 1280 + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=32, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=se_layer, + **kwargs, + ) + + features_only = False + model_cls = MobileNetV3 + kwargs_filter = None + if model_kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + model = build_model_with_cfg( + model_cls, + variant, + pretrained, + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs, + ) + if features_only: + model.default_cfg = pretrained_cfg_for_features(model.default_cfg) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'hardcorenas_a.miil_green_in1k': _cfg(hf_hub_id='timm/'), + 'hardcorenas_b.miil_green_in1k': _cfg(hf_hub_id='timm/'), + 'hardcorenas_c.miil_green_in1k': _cfg(hf_hub_id='timm/'), + 'hardcorenas_d.miil_green_in1k': _cfg(hf_hub_id='timm/'), + 'hardcorenas_e.miil_green_in1k': _cfg(hf_hub_id='timm/'), + 'hardcorenas_f.miil_green_in1k': _cfg(hf_hub_id='timm/'), +}) + + +@register_model +def hardcorenas_a(pretrained=False, **kwargs) -> MobileNetV3: + """ hardcorenas_A """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_b(pretrained=False, **kwargs) -> MobileNetV3: + """ hardcorenas_B """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], + ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'], + ['ir_r1_k5_s2_e3_c40_nre', 
'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'], + ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], + ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], + ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_c(pretrained=False, **kwargs) -> MobileNetV3: + """ hardcorenas_C """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', + 'ir_r1_k5_s1_e3_c40_nre'], + ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], + ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_d(pretrained=False, **kwargs) -> MobileNetV3: + """ hardcorenas_D """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'], + ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', + 'ir_r1_k3_s1_e3_c80_se0.25'], + ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25', + 'ir_r1_k5_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_e(pretrained=False, **kwargs) -> MobileNetV3: + """ hardcorenas_E """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', + 'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', + 'ir_r1_k5_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_f(pretrained=False, **kwargs) -> MobileNetV3: + """ hardcorenas_F """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', + 'ir_r1_k3_s1_e3_c80_se0.25'], + ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', + 'ir_r1_k3_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], 
['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs) + return model diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/helpers.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc82eb81ed9337dd2cac6e0a5ab9a44ad70834b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/helpers.py @@ -0,0 +1,7 @@ +from ._builder import * +from ._helpers import * +from ._manipulate import * +from ._prune import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hgnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f38af646f27f604e6ddf954743df67e104244b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hgnet.py @@ -0,0 +1,744 @@ +""" PP-HGNet (V1 & V2) + +Reference: +https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ImageNet1k/PP-HGNetV2.md +The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/docs/en/models/PP-HGNet_en.md) +PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py +PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d +from ._builder import build_model_with_cfg +from ._registry import register_model, generate_default_cfgs + +__all__ = ['HighPerfGpuNet'] + + +class LearnableAffineBlock(nn.Module): + def __init__( + self, + scale_value=1.0, + bias_value=0.0 + ): + super().__init__() + self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True) + self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True) + + def forward(self, x): + return self.scale * x + self.bias + + +class ConvBNAct(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size, + stride=1, + groups=1, + padding='', + use_act=True, + use_lab=False + ): + super().__init__() + self.use_act = use_act + self.use_lab = use_lab + self.conv = create_conv2d( + in_chs, + out_chs, + kernel_size, + stride=stride, + padding=padding, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_chs) + if self.use_act: + self.act = nn.ReLU() + else: + self.act = nn.Identity() + if self.use_act and self.use_lab: + self.lab = LearnableAffineBlock() + else: + self.lab = nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + x = self.lab(x) + return x + + +class LightConvBNAct(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size, + groups=1, + use_lab=False + ): + super().__init__() + self.conv1 = ConvBNAct( + in_chs, + out_chs, + kernel_size=1, + use_act=False, + use_lab=use_lab, + ) + self.conv2 = ConvBNAct( + out_chs, + out_chs, + kernel_size=kernel_size, + groups=out_chs, + use_act=True, + use_lab=use_lab, + ) + + def forward(self, 
x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class EseModule(nn.Module): + def __init__(self, chs): + super().__init__() + self.conv = nn.Conv2d( + chs, + chs, + kernel_size=1, + stride=1, + padding=0, + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + identity = x + x = x.mean((2, 3), keepdim=True) + x = self.conv(x) + x = self.sigmoid(x) + return torch.mul(identity, x) + + +class StemV1(nn.Module): + # for PP-HGNet + def __init__(self, stem_chs): + super().__init__() + self.stem = nn.Sequential(*[ + ConvBNAct( + stem_chs[i], + stem_chs[i + 1], + kernel_size=3, + stride=2 if i == 0 else 1) for i in range( + len(stem_chs) - 1) + ]) + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.stem(x) + x = self.pool(x) + return x + + +class StemV2(nn.Module): + # for PP-HGNetv2 + def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): + super().__init__() + self.stem1 = ConvBNAct( + in_chs, + mid_chs, + kernel_size=3, + stride=2, + use_lab=use_lab, + ) + self.stem2a = ConvBNAct( + mid_chs, + mid_chs // 2, + kernel_size=2, + stride=1, + use_lab=use_lab, + ) + self.stem2b = ConvBNAct( + mid_chs // 2, + mid_chs, + kernel_size=2, + stride=1, + use_lab=use_lab, + ) + self.stem3 = ConvBNAct( + mid_chs * 2, + mid_chs, + kernel_size=3, + stride=2, + use_lab=use_lab, + ) + self.stem4 = ConvBNAct( + mid_chs, + out_chs, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True) + + def forward(self, x): + x = self.stem1(x) + x = F.pad(x, (0, 1, 0, 1)) + x2 = self.stem2a(x) + x2 = F.pad(x2, (0, 1, 0, 1)) + x2 = self.stem2b(x2) + x1 = self.pool(x) + x = torch.cat([x1, x2], dim=1) + x = self.stem3(x) + x = self.stem4(x) + return x + + +class HighPerfGpuBlock(nn.Module): + def __init__( + self, + in_chs, + mid_chs, + out_chs, + layer_num, + kernel_size=3, + residual=False, + light_block=False, + use_lab=False, + agg='ese', + drop_path=0., + ): + super().__init__() + self.residual = residual + + self.layers = nn.ModuleList() + for i in range(layer_num): + if light_block: + self.layers.append( + LightConvBNAct( + in_chs if i == 0 else mid_chs, + mid_chs, + kernel_size=kernel_size, + use_lab=use_lab, + ) + ) + else: + self.layers.append( + ConvBNAct( + in_chs if i == 0 else mid_chs, + mid_chs, + kernel_size=kernel_size, + stride=1, + use_lab=use_lab, + ) + ) + + # feature aggregation + total_chs = in_chs + layer_num * mid_chs + if agg == 'se': + aggregation_squeeze_conv = ConvBNAct( + total_chs, + out_chs // 2, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + aggregation_excitation_conv = ConvBNAct( + out_chs // 2, + out_chs, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + self.aggregation = nn.Sequential( + aggregation_squeeze_conv, + aggregation_excitation_conv, + ) + else: + aggregation_conv = ConvBNAct( + total_chs, + out_chs, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + att = EseModule(out_chs) + self.aggregation = nn.Sequential( + aggregation_conv, + att, + ) + + self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() + + def forward(self, x): + identity = x + output = [x] + for layer in self.layers: + x = layer(x) + output.append(x) + x = torch.cat(output, dim=1) + x = self.aggregation(x) + if self.residual: + x = self.drop_path(x) + identity + return x + + +class HighPerfGpuStage(nn.Module): + def __init__( + self, + in_chs, + mid_chs, + out_chs, + block_num, + layer_num, + downsample=True, + stride=2, + light_block=False, + kernel_size=3, 
+            use_lab=False,
+            agg='ese',
+            drop_path=0.,
+    ):
+        super().__init__()
+        if downsample:
+            self.downsample = ConvBNAct(
+                in_chs,
+                in_chs,
+                kernel_size=3,
+                stride=stride,
+                groups=in_chs,
+                use_act=False,
+                use_lab=use_lab,
+            )
+        else:
+            self.downsample = nn.Identity()
+
+        blocks_list = []
+        for i in range(block_num):
+            blocks_list.append(
+                HighPerfGpuBlock(
+                    in_chs if i == 0 else out_chs,
+                    mid_chs,
+                    out_chs,
+                    layer_num,
+                    residual=False if i == 0 else True,
+                    kernel_size=kernel_size,
+                    light_block=light_block,
+                    use_lab=use_lab,
+                    agg=agg,
+                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
+                )
+            )
+        self.blocks = nn.Sequential(*blocks_list)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.blocks(x)
+        return x
+
+
+class ClassifierHead(nn.Module):
+    def __init__(
+            self,
+            num_features,
+            num_classes,
+            pool_type='avg',
+            drop_rate=0.,
+            use_last_conv=True,
+            class_expand=2048,
+            use_lab=False
+    ):
+        super(ClassifierHead, self).__init__()
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=False, input_fmt='NCHW')
+        if use_last_conv:
+            last_conv = nn.Conv2d(
+                num_features,
+                class_expand,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+            act = nn.ReLU()
+            if use_lab:
+                lab = LearnableAffineBlock()
+                self.last_conv = nn.Sequential(last_conv, act, lab)
+            else:
+                self.last_conv = nn.Sequential(last_conv, act)
+        else:
+            self.last_conv = nn.Identity()
+
+        if drop_rate > 0:
+            self.dropout = nn.Dropout(drop_rate)
+        else:
+            self.dropout = nn.Identity()
+
+        self.flatten = nn.Flatten()
+        self.fc = nn.Linear(class_expand if use_last_conv else num_features, num_classes)
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.last_conv(x)
+        x = self.dropout(x)
+        x = self.flatten(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
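# --- Editor's aside (illustrative sketch, not diff content) -----------------
# Quick shape check for the ClassifierHead above: pre_logits=True returns the
# pooled (and optionally expanded) features before the final Linear. Assumes
# this module is importable as timm.models.hgnet.
import torch
from timm.models.hgnet import ClassifierHead

head = ClassifierHead(num_features=1024, num_classes=1000, class_expand=2048)
feats = torch.randn(2, 1024, 7, 7)
assert head(feats).shape == (2, 1000)
assert head(feats, pre_logits=True).shape == (2, 2048)
# -----------------------------------------------------------------------------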
+
+
+class HighPerfGpuNet(nn.Module):
+
+    def __init__(
+            self,
+            cfg,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            use_last_conv=True,
+            class_expand=2048,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            use_lab=False,
+            **kwargs,
+    ):
+        super(HighPerfGpuNet, self).__init__()
+        stem_type = cfg["stem_type"]
+        stem_chs = cfg["stem_chs"]
+        stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]]
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        self.use_last_conv = use_last_conv
+        self.class_expand = class_expand
+        self.use_lab = use_lab
+
+        assert stem_type in ['v1', 'v2']
+        if stem_type == 'v2':
+            self.stem = StemV2(
+                in_chs=in_chans,
+                mid_chs=stem_chs[0],
+                out_chs=stem_chs[1],
+                use_lab=use_lab)
+        else:
+            self.stem = StemV1([in_chans] + stem_chs)
+
+        current_stride = 4
+
+        stages = []
+        self.feature_info = []
+        block_depths = [c[3] for c in stages_cfg]
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_depths)).split(block_depths)]
+        for i, stage_config in enumerate(stages_cfg):
+            in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config
+            stages += [HighPerfGpuStage(
+                in_chs=in_chs,
+                mid_chs=mid_chs,
+                out_chs=out_chs,
+                block_num=block_num,
+                layer_num=layer_num,
+                downsample=downsample,
+                light_block=light_block,
+                kernel_size=kernel_size,
+                use_lab=use_lab,
+                agg='ese' if stem_type == 'v1' else 'se',
+                drop_path=dpr[i],
+            )]
+            self.num_features = out_chs
+            if downsample:
+                current_stride *= 2
+            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
+        self.stages = nn.Sequential(*stages)
+
+        if num_classes > 0:
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes=num_classes,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
+                use_last_conv=use_last_conv,
+                class_expand=class_expand,
+                use_lab=use_lab
+            )
+        else:
+            if global_pool == 'avg':
+                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
+
+        for n, m in self.named_modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.zeros_(m.bias)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        if num_classes > 0:
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes=num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                use_last_conv=self.use_last_conv,
+                class_expand=self.class_expand,
+                use_lab=self.use_lab)
+        else:
+            if global_pool:
+                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        return self.stages(x)
+
+    def forward_head(self, x, pre_logits: bool = False):
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+model_cfgs = dict(
+    # PP-HGNet
+    hgnet_tiny={
+        "stem_type": 'v1',
+        "stem_chs": [48, 48, 96],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
+        "stage1": [96, 96, 224, 1, False, False, 3, 5],
+        "stage2": [224, 128, 448, 1, True, False, 3, 5],
+        "stage3": [448, 160, 512, 2, True, False, 3, 5],
+        "stage4": [512, 192, 768, 1, True, False, 3, 5],
+    },
+    hgnet_small={
+        "stem_type": 'v1',
+        "stem_chs": [64, 64, 128],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
+        "stage1": [128, 128, 256, 1, False, False, 3, 6],
+        "stage2": [256, 160, 512, 1, True, False, 3, 6],
+        "stage3": [512, 192, 768, 2, True, False, 3, 6],
+        "stage4": [768, 224, 1024, 1, True, False, 3, 6],
+    },
+    hgnet_base={
+        "stem_type": 'v1',
+        "stem_chs": [96, 96, 160],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
+        "stage1": [160, 192, 320, 1, False, False, 3, 7],
+        "stage2": [320, 224, 640, 2, True, False, 3, 7],
+        "stage3": [640, 256, 960, 3, True, False, 3, 7],
+        "stage4": [960, 288, 1280, 2, True, False, 3, 7],
+    },
+    # PP-HGNetv2
+    hgnetv2_b0={
+        "stem_type": 'v2',
+        "stem_chs": [16, 16],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
+        "stage1": [16, 16, 64, 1, False, False, 3, 3],
+        "stage2": [64, 32, 256, 1, True, False, 3, 3],
+        "stage3": [256, 64, 512, 2, True, True, 5, 3],
+        "stage4": [512, 128, 1024, 1, True, True, 5, 3],
+    },
+    hgnetv2_b1={
+        "stem_type": 'v2',
+        "stem_chs": [24, 32],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
+        "stage1": [32, 32, 64, 1, False, False, 3, 3],
+        "stage2": [64, 48, 256, 1,
True, False, 3, 3], + "stage3": [256, 96, 512, 2, True, True, 5, 3], + "stage4": [512, 192, 1024, 1, True, True, 5, 3], + }, + hgnetv2_b2={ + "stem_type": 'v2', + "stem_chs": [24, 32], + # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 96, 1, False, False, 3, 4], + "stage2": [96, 64, 384, 1, True, False, 3, 4], + "stage3": [384, 128, 768, 3, True, True, 5, 4], + "stage4": [768, 256, 1536, 1, True, True, 5, 4], + }, + hgnetv2_b3={ + "stem_type": 'v2', + "stem_chs": [24, 32], + # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 128, 1, False, False, 3, 5], + "stage2": [128, 64, 512, 1, True, False, 3, 5], + "stage3": [512, 128, 1024, 3, True, True, 5, 5], + "stage4": [1024, 256, 2048, 1, True, True, 5, 5], + }, + hgnetv2_b4={ + "stem_type": 'v2', + "stem_chs": [32, 48], + # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num + "stage1": [48, 48, 128, 1, False, False, 3, 6], + "stage2": [128, 96, 512, 1, True, False, 3, 6], + "stage3": [512, 192, 1024, 3, True, True, 5, 6], + "stage4": [1024, 384, 2048, 1, True, True, 5, 6], + }, + hgnetv2_b5={ + "stem_type": 'v2', + "stem_chs": [32, 64], + # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num + "stage1": [64, 64, 128, 1, False, False, 3, 6], + "stage2": [128, 128, 512, 2, True, False, 3, 6], + "stage3": [512, 256, 1024, 5, True, True, 5, 6], + "stage4": [1024, 512, 2048, 2, True, True, 5, 6], + }, + hgnetv2_b6={ + "stem_type": 'v2', + "stem_chs": [48, 96], + # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num + "stage1": [96, 96, 192, 2, False, False, 3, 6], + "stage2": [192, 192, 512, 3, True, False, 3, 6], + "stage3": [512, 384, 1024, 6, True, True, 5, 6], + "stage4": [1024, 768, 2048, 3, True, True, 5, 6], + }, +) + + +def _create_hgnet(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + return build_model_with_cfg( + HighPerfGpuNet, + variant, + pretrained, + model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.965, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'classifier': 'head.fc', 'first_conv': 'stem.stem1.conv', + 'test_crop_pct': 1.0, 'test_input_size': (3, 288, 288), + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'hgnet_tiny.paddle_in1k': _cfg( + first_conv='stem.stem.0.conv', + hf_hub_id='timm/'), + 'hgnet_tiny.ssld_in1k': _cfg( + first_conv='stem.stem.0.conv', + hf_hub_id='timm/'), + 'hgnet_small.paddle_in1k': _cfg( + first_conv='stem.stem.0.conv', + hf_hub_id='timm/'), + 'hgnet_small.ssld_in1k': _cfg( + first_conv='stem.stem.0.conv', + hf_hub_id='timm/'), + 'hgnet_base.ssld_in1k': _cfg( + first_conv='stem.stem.0.conv', + hf_hub_id='timm/'), + 'hgnetv2_b0.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b0.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b1.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b1.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b2.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b2.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b3.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 
'hgnetv2_b3.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b4.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b4.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b5.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b5.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b6.ssld_stage2_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'hgnetv2_b6.ssld_stage1_in22k_in1k': _cfg( + hf_hub_id='timm/'), +}) + + +@register_model +def hgnet_tiny(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnet_tiny', pretrained=pretrained, **kwargs) + + +@register_model +def hgnet_small(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnet_small', pretrained=pretrained, **kwargs) + + +@register_model +def hgnet_base(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnet_base', pretrained=pretrained, **kwargs) + + +@register_model +def hgnetv2_b0(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b0', pretrained=pretrained, use_lab=True, **kwargs) + + +@register_model +def hgnetv2_b1(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b1', pretrained=pretrained, use_lab=True, **kwargs) + + +@register_model +def hgnetv2_b2(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b2', pretrained=pretrained, use_lab=True, **kwargs) + + +@register_model +def hgnetv2_b3(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b3', pretrained=pretrained, use_lab=True, **kwargs) + + +@register_model +def hgnetv2_b4(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b4', pretrained=pretrained, **kwargs) + + +@register_model +def hgnetv2_b5(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b5', pretrained=pretrained, **kwargs) + + +@register_model +def hgnetv2_b6(pretrained=False, **kwargs) -> HighPerfGpuNet: + return _create_hgnet('hgnetv2_b6', pretrained=pretrained, **kwargs) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hrnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..20ea7674be3c7012721505f7078ef0383cab0f73 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hrnet.py @@ -0,0 +1,974 @@ +""" HRNet + +Copied from https://github.com/HRNet/HRNet-Image-Classification + +Original header: + Copyright (c) Microsoft + Licensed under the MIT License. 
+ Written by Bin Xiao (Bin.Xiao@microsoft.com) + Modified by Ke Sun (sunk@mail.ustc.edu.cn) +""" +import logging +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import create_classifier +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._features import FeatureInfo +from ._registry import register_model, generate_default_cfgs +from .resnet import BasicBlock, Bottleneck # leveraging ResNet block_types w/ additional features like SE + +__all__ = ['HighResolutionNet', 'HighResolutionNetFeatures'] # model_registry will add each entrypoint fn to this + +_BN_MOMENTUM = 0.1 +_logger = logging.getLogger(__name__) + + +cfg_cls = dict( + hrnet_w18_small=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(1,), + num_channels=(32,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(2, 2), + num_channels=(16, 32), + fuse_method='SUM' + ), + stage3=dict( + num_modules=1, + num_branches=3, + block_type='BASIC', + num_blocks=(2, 2, 2), + num_channels=(16, 32, 64), + fuse_method='SUM' + ), + stage4=dict( + num_modules=1, + num_branches=4, + block_type='BASIC', + num_blocks=(2, 2, 2, 2), + num_channels=(16, 32, 64, 128), + fuse_method='SUM', + ), + ), + + hrnet_w18_small_v2=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(2,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(2, 2), + num_channels=(18, 36), + fuse_method='SUM' + ), + stage3=dict( + num_modules=3, + num_branches=3, + block_type='BASIC', + num_blocks=(2, 2, 2), + num_channels=(18, 36, 72), + fuse_method='SUM' + ), + stage4=dict( + num_modules=2, + num_branches=4, + block_type='BASIC', + num_blocks=(2, 2, 2, 2), + num_channels=(18, 36, 72, 144), + fuse_method='SUM', + ), + ), + + hrnet_w18=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144), + fuse_method='SUM', + ), + ), + + hrnet_w30=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(30, 60), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(30, 60, 120), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(30, 60, 120, 240), + fuse_method='SUM', + ), + ), + + hrnet_w32=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + 
stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(32, 64), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(32, 64, 128), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(32, 64, 128, 256), + fuse_method='SUM', + ), + ), + + hrnet_w40=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(40, 80), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(40, 80, 160), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(40, 80, 160, 320), + fuse_method='SUM', + ), + ), + + hrnet_w44=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(44, 88), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(44, 88, 176), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(44, 88, 176, 352), + fuse_method='SUM', + ), + ), + + hrnet_w48=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(48, 96), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(48, 96, 192), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(48, 96, 192, 384), + fuse_method='SUM', + ), + ), + + hrnet_w64=dict( + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(64, 128), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(64, 128, 256), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(64, 128, 256, 512), + fuse_method='SUM', + ), + ) +) + + +class HighResolutionModule(nn.Module): + def __init__( + self, + num_branches, + block_types, + num_blocks, + num_in_chs, + num_channels, + fuse_method, + multi_scale_output=True, + ): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, + block_types, + num_blocks, + num_in_chs, + num_channels, + ) + + self.num_in_chs = num_in_chs + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + 
num_branches, + block_types, + num_blocks, + num_channels, + ) + self.fuse_layers = self._make_fuse_layers() + self.fuse_act = nn.ReLU(False) + + def _check_branches(self, num_branches, block_types, num_blocks, num_in_chs, num_channels): + error_msg = '' + if num_branches != len(num_blocks): + error_msg = 'num_branches({}) <> num_blocks({})'.format(num_branches, len(num_blocks)) + elif num_branches != len(num_channels): + error_msg = 'num_branches({}) <> num_channels({})'.format(num_branches, len(num_channels)) + elif num_branches != len(num_in_chs): + error_msg = 'num_branches({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs)) + if error_msg: + _logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block_type, num_blocks, num_channels, stride=1): + downsample = None + if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block_type.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_in_chs[branch_index], num_channels[branch_index] * block_type.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block_type.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block_type(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)] + self.num_in_chs[branch_index] = num_channels[branch_index] * block_type.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block_type(self.num_in_chs[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block_type, num_blocks, num_channels): + branches = [] + for i in range(num_branches): + branches.append(self._make_one_branch(i, block_type, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return nn.Identity() + + num_branches = self.num_branches + num_in_chs = self.num_in_chs + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM), + nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(nn.Identity()) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_out_chs_conv3x3 = num_in_chs[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM) + )) + else: + num_out_chs_conv3x3 = num_in_chs[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM), + nn.ReLU(False) + )) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_in_chs(self): + return self.num_in_chs + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i, branch in enumerate(self.branches): + x[i] = branch(x[i]) + + x_fuse = [] + for i, fuse_outer in enumerate(self.fuse_layers): + y = None + for j, f in enumerate(fuse_outer): + if y is None: + y = f(x[j]) + else: + y = y + f(x[j]) + x_fuse.append(self.fuse_act(y)) + return x_fuse + + +class SequentialList(nn.Sequential): + + def 
__init__(self, *args): + super(SequentialList, self).__init__(*args) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (List[torch.Tensor]) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (List[torch.Tensor]) + pass + + def forward(self, x) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + +@torch.jit.interface +class ModuleInterface(torch.nn.Module): + def forward(self, input: torch.Tensor) -> torch.Tensor: # `input` has a same name in Sequential forward + pass + + +block_types_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__( + self, + cfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0.0, + head='classification', + **kwargs, + ): + super(HighResolutionNet, self).__init__() + self.num_classes = num_classes + assert output_stride == 32 # FIXME support dilation + + cfg.update(**kwargs) + stem_width = cfg['stem_width'] + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) + self.act2 = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = block_types_dict[self.stage1_cfg['block_type']] + num_blocks = self.stage1_cfg['num_blocks'][0] + self.layer1 = self._make_layer(block_type, 64, num_channels, num_blocks) + stage1_out_channel = block_type.expansion * num_channels + + self.stage2_cfg = cfg['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = block_types_dict[self.stage2_cfg['block_type']] + num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = block_types_dict[self.stage3_cfg['block_type']] + num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = block_types_dict[self.stage4_cfg['block_type']] + num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + + self.head = head + self.head_channels = None # set if _make_head called + head_conv_bias = cfg.pop('head_conv_bias', True) + if head == 'classification': + # Classification Head + self.num_features = 2048 + self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head( + pre_stage_channels, + conv_bias=head_conv_bias, + ) + self.global_pool, self.head_drop, self.classifier = create_classifier( + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) + else: + 
if head == 'incre': + self.num_features = 2048 + self.incre_modules, _, _ = self._make_head(pre_stage_channels, incre_only=True) + else: + self.num_features = 256 + self.incre_modules = None + self.global_pool = nn.Identity() + self.head_drop = nn.Identity() + self.classifier = nn.Identity() + + curr_stride = 2 + # module names aren't actually valid here, hook or FeatureNet based extraction would not work + self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')] + for i, c in enumerate(self.head_channels if self.head_channels else num_channels): + curr_stride *= 2 + c = c * 4 if self.head_channels else c # head block_type expansion factor of 4 + self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')] + + self.init_weights() + + def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True): + head_block_type = Bottleneck + self.head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_modules.append(self._make_layer(head_block_type, channels, self.head_channels[i], 1, stride=1)) + incre_modules = nn.ModuleList(incre_modules) + if incre_only: + return incre_modules, None, None + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = self.head_channels[i] * head_block_type.expansion + out_channels = self.head_channels[i + 1] * head_block_type.expansion + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=2, padding=1, bias=conv_bias), + nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=self.head_channels[3] * head_block_type.expansion, out_channels=self.num_features, + kernel_size=1, stride=1, padding=0, bias=conv_bias), + nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + _in_chs = num_channels_pre_layer[-1] + _out_chs = num_channels_cur_layer[i] if j == i - num_branches_pre else _in_chs + conv3x3s.append(nn.Sequential( + nn.Conv2d(_in_chs, _out_chs, 3, 2, 1, bias=False), + nn.BatchNorm2d(_out_chs, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block_type, inplanes, planes, block_types, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block_type.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block_type.expansion, kernel_size=1, 
stride=stride, bias=False), + nn.BatchNorm2d(planes * block_type.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block_type(inplanes, planes, stride, downsample)] + inplanes = planes * block_type.expansion + for i in range(1, block_types): + layers.append(block_type(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True): + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block_type = block_types_dict[layer_config['block_type']] + fuse_method = layer_config['fuse_method'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used by the last module + reset_multi_scale_output = multi_scale_output or i < num_modules - 1 + modules.append(HighResolutionModule( + num_branches, block_type, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output) + ) + num_in_chs = modules[-1].get_num_in_chs() + + return SequentialList(*modules), num_in_chs + + @torch.jit.ignore + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv[12]|bn[12]', + block_types=r'^(?:layer|stage|transition)(\d+)' if coarse else [ + (r'^layer(\d+)\.(\d+)', None), + (r'^stage(\d+)\.(\d+)', None), + (r'^transition(\d+)', (99999,)), + ], + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "gradient checkpointing not supported" + + @torch.jit.ignore + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def stages(self, x) -> List[torch.Tensor]: + x = self.layer1(x) + + xl = [t(x) for i, t in enumerate(self.transition1)] + yl = self.stage2(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition2)] + yl = self.stage3(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition3)] + yl = self.stage4(xl) + return yl + + def forward_features(self, x): + # Stem + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + # Stages + yl = self.stages(x) + if self.incre_modules is None or self.downsamp_modules is None: + return yl + + y = None + for i, incre in enumerate(self.incre_modules): + if y is None: + y = incre(yl[i]) + else: + down: ModuleInterface = self.downsamp_modules[i - 1] # needed for torchscript module indexing + y = incre(yl[i]) + down.forward(y) + + y = self.final_layer(y) + return y + + def forward_head(self, x, pre_logits: bool = False): + # Classification Head + x = self.global_pool(x) + x = self.head_drop(x) + return x if pre_logits else self.classifier(x) + + def forward(self, x): + y = self.forward_features(x) + x = self.forward_head(y) + return x + + +class HighResolutionNetFeatures(HighResolutionNet): + """HighResolutionNet feature extraction + + The design of HRNet makes it easy to grab feature maps; this class provides a simple wrapper to do so (see the usage sketch below).
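A minimal usage sketch (an editorial illustration, hedged: it assumes the timm registry entries defined later in this file and is not part of the original source):

    import torch
    import timm

    # features_only=True routes _create_hrnet to this feature-extraction class
    model = timm.create_model('hrnet_w18', features_only=True, pretrained=False)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(f.shape)  # stride-2 stem output, then one map per fused stage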
+ It would be more complicated to use the FeatureNet helpers. + + The `feature_location=incre` allows grabbing increased channel count features using part of the + classification head. If `feature_location=''` the default HRNet features are returned. First stem + conv is used for stride 2 features. + """ + + def __init__( + self, + cfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0.0, + feature_location='incre', + out_indices=(0, 1, 2, 3, 4), + **kwargs, + ): + assert feature_location in ('incre', '') + super(HighResolutionNetFeatures, self).__init__( + cfg, + in_chans=in_chans, + num_classes=num_classes, + output_stride=output_stride, + global_pool=global_pool, + drop_rate=drop_rate, + head=feature_location, + **kwargs, + ) + self.feature_info = FeatureInfo(self.feature_info, out_indices) + self._out_idx = {f['index'] for f in self.feature_info.get_dicts()} + + def forward_features(self, x): + assert False, 'Not supported' + + def forward(self, x) -> List[torch.Tensor]: + out = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + if 0 in self._out_idx: + out.append(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + x = self.stages(x) + if self.incre_modules is not None: + x = [incre(f) for f, incre in zip(x, self.incre_modules)] + for i, f in enumerate(x): + if i + 1 in self._out_idx: + out.append(f) + return out + + +def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs): + model_cls = HighResolutionNet + features_only = False + kwargs_filter = None + if model_kwargs.pop('features_only', False): + model_cls = HighResolutionNetFeatures + kwargs_filter = ('num_classes', 'global_pool') + features_only = True + cfg_variant = cfg_variant or variant + model = build_model_with_cfg( + model_cls, + variant, + pretrained, + model_cfg=cfg_cls[cfg_variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs, + ) + if features_only: + model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) + model.default_cfg = model.pretrained_cfg # backwards compat + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'hrnet_w18_small.gluon_in1k': _cfg(hf_hub_id='timm/', interpolation='bicubic'), + 'hrnet_w18_small.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w18_small_v2.gluon_in1k': _cfg(hf_hub_id='timm/', interpolation='bicubic'), + 'hrnet_w18_small_v2.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w18.ms_aug_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, + ), + 'hrnet_w18.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w30.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w32.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w40.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w44.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w48.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w64.ms_in1k': _cfg(hf_hub_id='timm/'), + + 'hrnet_w18_ssld.paddle_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288) + ), + 'hrnet_w48_ssld.paddle_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288) + ), +}) + + +@register_model +def hrnet_w18_small(pretrained=False, **kwargs) -> HighResolutionNet: + return
_create_hrnet('hrnet_w18_small', pretrained, **kwargs) + + +@register_model +def hrnet_w18_small_v2(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs) + + +@register_model +def hrnet_w18(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w18', pretrained, **kwargs) + + +@register_model +def hrnet_w30(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w30', pretrained, **kwargs) + + +@register_model +def hrnet_w32(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w32', pretrained, **kwargs) + + +@register_model +def hrnet_w40(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w40', pretrained, **kwargs) + + +@register_model +def hrnet_w44(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w44', pretrained, **kwargs) + + +@register_model +def hrnet_w48(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w48', pretrained, **kwargs) + + +@register_model +def hrnet_w64(pretrained=False, **kwargs) -> HighResolutionNet: + return _create_hrnet('hrnet_w64', pretrained, **kwargs) + + +@register_model +def hrnet_w18_ssld(pretrained=False, **kwargs) -> HighResolutionNet: + kwargs.setdefault('head_conv_bias', False) + return _create_hrnet('hrnet_w18_ssld', cfg_variant='hrnet_w18', pretrained=pretrained, **kwargs) + + +@register_model +def hrnet_w48_ssld(pretrained=False, **kwargs) -> HighResolutionNet: + kwargs.setdefault('head_conv_bias', False) + return _create_hrnet('hrnet_w48_ssld', cfg_variant='hrnet_w48', pretrained=pretrained, **kwargs) + diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hub.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc3a921c591f3ab344ba777e07d454a9b324ce9 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/hub.py @@ -0,0 +1,4 @@ +from ._hub import * + +import warnings +warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_next.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_next.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d37db981e483480e0e274d3ad49684574e1f8b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_next.py @@ -0,0 +1,441 @@ +""" +InceptionNeXt paper: https://arxiv.org/abs/2303.16900 +Original implementation & weights from: https://github.com/sail-sg/inceptionnext +""" + +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class InceptionDWConv2d(nn.Module): + """ Inception depthwise convolution + """ + + def __init__( + self, + in_chs, + square_kernel_size=3, + band_kernel_size=11, + branch_ratio=0.125, + dilation=1, + ): + super().__init__() + + gc = int(in_chs * branch_ratio) # channel numbers of a 
convolution branch + square_padding = get_padding(square_kernel_size, dilation=dilation) + band_padding = get_padding(band_kernel_size, dilation=dilation) + self.dwconv_hw = nn.Conv2d( + gc, gc, square_kernel_size, + padding=square_padding, dilation=dilation, groups=gc) + self.dwconv_w = nn.Conv2d( + gc, gc, (1, band_kernel_size), + padding=(0, band_padding), dilation=(1, dilation), groups=gc) + self.dwconv_h = nn.Conv2d( + gc, gc, (band_kernel_size, 1), + padding=(band_padding, 0), dilation=(dilation, 1), groups=gc) + self.split_indexes = (in_chs - 3 * gc, gc, gc, gc) + + def forward(self, x): + x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1) + return torch.cat(( + x_id, + self.dwconv_hw(x_hw), + self.dwconv_w(x_w), + self.dwconv_h(x_h) + ), dim=1, + ) + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + copied from timm: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/timm/models/layers/mlp.py + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.ReLU, + norm_layer=None, + bias=True, + drop=0., + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x + + +class MlpClassifierHead(nn.Module): + """ MLP classification head + """ + + def __init__( + self, + dim, + num_classes=1000, + pool_type='avg', + mlp_ratio=3, + act_layer=nn.GELU, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + bias=True + ): + super().__init__() + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) + in_features = dim * self.global_pool.feat_mult() + hidden_features = int(mlp_ratio * in_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.global_pool(x) + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.drop(x) + x = self.fc2(x) + return x + + +class MetaNeXtBlock(nn.Module): + """ MetaNeXtBlock Block + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + dilation=1, + token_mixer=InceptionDWConv2d, + norm_layer=nn.BatchNorm2d, + mlp_layer=ConvMlp, + mlp_ratio=4, + act_layer=nn.GELU, + ls_init_value=1e-6, + drop_path=0., + + ): + super().__init__() + self.token_mixer = token_mixer(dim, dilation=dilation) + self.norm = norm_layer(dim) + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.token_mixer(x) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + return x + + +class MetaNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + dilation=(1, 1), + drop_path_rates=None, + ls_init_value=1.0, + token_mixer=InceptionDWConv2d, + act_layer=nn.GELU, + norm_layer=None, + mlp_ratio=4, + ): + super().__init__() + self.grad_checkpointing = False + if stride > 1 or dilation[0] != dilation[1]: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d( + in_chs, + out_chs, + kernel_size=2, + stride=stride, + dilation=dilation[0], + ), + ) + else: + self.downsample = nn.Identity() + + drop_path_rates = drop_path_rates or [0.] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(MetaNeXtBlock( + dim=out_chs, + dilation=dilation[1], + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + token_mixer=token_mixer, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_ratio=mlp_ratio, + )) + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class MetaNeXt(nn.Module): + r""" MetaNeXt + A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/abs/2303.16900 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3) + dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768) + token_mixers: Token mixer function. Default: InceptionDWConv2d + norm_layer: Normalization layer. Default: nn.BatchNorm2d + act_layer: Activation function for MLP. Default: nn.GELU + mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3) + head_fn: classifier head + drop_rate (float): Head dropout rate + drop_path_rate (float): Stochastic depth rate. Default: 0. + ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
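An instantiation sketch for reference (an editorial illustration, hedged: it assumes this module is importable as timm.models.inception_next, matching the diff path above):

    import torch
    from timm.models.inception_next import MetaNeXt, InceptionDWConv2d

    # dims[0]=96 with the default branch_ratio=0.125 gives each depthwise
    # branch of the token mixer gc=12 channels; the remaining 60 channels
    # pass through the identity branch unchanged.
    model = MetaNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d)
    model.eval()
    print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])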
+ """ + + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + depths=(3, 3, 9, 3), + dims=(96, 192, 384, 768), + token_mixers=InceptionDWConv2d, + norm_layer=nn.BatchNorm2d, + act_layer=nn.GELU, + mlp_ratios=(4, 4, 4, 3), + head_fn=MlpClassifierHead, + drop_rate=0., + drop_path_rate=0., + ls_init_value=1e-6, + ): + super().__init__() + + num_stage = len(depths) + if not isinstance(token_mixers, (list, tuple)): + token_mixers = [token_mixers] * num_stage + if not isinstance(mlp_ratios, (list, tuple)): + mlp_ratios = [mlp_ratios] * num_stage + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + self.feature_info = [] + + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + norm_layer(dims[0]) + ) + + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + prev_chs = dims[0] + curr_stride = 4 + dilation = 1 + # feature resolution stages, each consisting of multiple residual blocks + self.stages = nn.Sequential() + for i in range(num_stage): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + out_chs = dims[i] + self.stages.append(MetaNeXtStage( + prev_chs, + out_chs, + stride=stride if i > 0 else 1, + dilation=(first_dilation, dilation), + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + act_layer=act_layer, + token_mixer=token_mixers[i], + norm_layer=norm_layer, + mlp_ratio=mlp_ratios[i], + )) + prev_chs = out_chs + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.num_features = prev_chs + if self.num_classes > 0: + if issubclass(head_fn, MlpClassifierHead): + assert self.global_pool, 'Cannot disable global pooling with MLP head present.' + self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate) + else: + if self.global_pool: + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + ] + ) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc2 + + def reset_classifier(self, num_classes=0, global_pool=None, head_fn=MlpClassifierHead): + if global_pool is not None: + self.global_pool = global_pool + if num_classes > 0: + if issubclass(head_fn, MlpClassifierHead): + assert self.global_pool, 'Cannot disable global pooling with MLP head present.' 
+ self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=self.drop_rate) + else: + if self.global_pool: + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return set() + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if pre_logits: + if hasattr(self.head, 'global_pool'): + x = self.head.global_pool(x) + return x + return self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc2', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'inception_next_tiny.sail_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth', + ), + 'inception_next_small.sail_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth', + ), + 'inception_next_base.sail_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth', + crop_pct=0.95, + ), + 'inception_next_base.sail_in1k_384': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), +}) + + +def _create_inception_next(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + MetaNeXt, variant, pretrained, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs, + ) + return model + + +@register_model +def inception_next_tiny(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), + token_mixers=InceptionDWConv2d, + ) + return _create_inception_next('inception_next_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def inception_next_small(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), + token_mixers=InceptionDWConv2d, + ) + return _create_inception_next('inception_next_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def inception_next_base(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), + token_mixers=InceptionDWConv2d, + ) + return _create_inception_next('inception_next_base', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_resnet_v2.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_resnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f4efaf520d1421499ec5d61a58520a8cb30e2443 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_resnet_v2.py @@ -0,0 +1,341 @@ +""" Pytorch Inception-Resnet-V2 implementation 
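A usage sketch (an editorial illustration, hedged: it relies on the register_model entrypoint at the bottom of this file and the 299x299 input size from its default_cfg):

    import torch
    import timm

    # The residual blocks in this file compute out = out * self.scale + x
    # before the block's ReLU (omitted in the final Block8); per stage the
    # scales are 0.17 (Block35), 0.10 (Block17) and 0.20 (Block8).
    model = timm.create_model('inception_resnet_v2', pretrained=False)
    model.eval()
    print(model(torch.randn(1, 3, 299, 299)).shape)  # torch.Size([1, 1000])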
+Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import create_classifier, ConvNormAct +from ._builder import build_model_with_cfg +from ._manipulate import flatten_modules +from ._registry import register_model, generate_default_cfgs, register_model_deprecations + +__all__ = ['InceptionResnetV2'] + + +class Mixed_5b(nn.Module): + def __init__(self, conv_block=None): + super(Mixed_5b, self).__init__() + conv_block = conv_block or ConvNormAct + + self.branch0 = conv_block(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + conv_block(192, 48, kernel_size=1, stride=1), + conv_block(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + conv_block(192, 64, kernel_size=1, stride=1), + conv_block(64, 96, kernel_size=3, stride=1, padding=1), + conv_block(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + conv_block(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + def __init__(self, scale=1.0, conv_block=None): + super(Block35, self).__init__() + self.scale = scale + conv_block = conv_block or ConvNormAct + + self.branch0 = conv_block(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + conv_block(320, 32, kernel_size=1, stride=1), + conv_block(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + conv_block(320, 32, kernel_size=1, stride=1), + conv_block(32, 48, kernel_size=3, stride=1, padding=1), + conv_block(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.act = nn.ReLU() + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.act(out) + return out + + +class Mixed_6a(nn.Module): + def __init__(self, conv_block=None): + super(Mixed_6a, self).__init__() + conv_block = conv_block or ConvNormAct + + self.branch0 = conv_block(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + conv_block(320, 256, kernel_size=1, stride=1), + conv_block(256, 256, kernel_size=3, stride=1, padding=1), + conv_block(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + def __init__(self, scale=1.0, conv_block=None): + super(Block17, self).__init__() + self.scale = scale + conv_block = conv_block or ConvNormAct + + self.branch0 = conv_block(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + conv_block(1088, 128, kernel_size=1, stride=1), + conv_block(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), + conv_block(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.act = nn.ReLU() + + def 
forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.act(out) + return out + + +class Mixed_7a(nn.Module): + def __init__(self, conv_block=None): + super(Mixed_7a, self).__init__() + conv_block = conv_block or ConvNormAct + + self.branch0 = nn.Sequential( + conv_block(1088, 256, kernel_size=1, stride=1), + conv_block(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + conv_block(1088, 256, kernel_size=1, stride=1), + conv_block(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + conv_block(1088, 256, kernel_size=1, stride=1), + conv_block(256, 288, kernel_size=3, stride=1, padding=1), + conv_block(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, no_relu=False, conv_block=None): + super(Block8, self).__init__() + self.scale = scale + conv_block = conv_block or ConvNormAct + + self.branch0 = conv_block(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + conv_block(2080, 192, kernel_size=1, stride=1), + conv_block(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), + conv_block(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + self.relu = None if no_relu else nn.ReLU() + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if self.relu is not None: + out = self.relu(out) + return out + + +class InceptionResnetV2(nn.Module): + def __init__( + self, + num_classes=1000, + in_chans=3, + drop_rate=0., + output_stride=32, + global_pool='avg', + norm_layer='batchnorm2d', + norm_eps=1e-3, + act_layer='relu', + ): + super(InceptionResnetV2, self).__init__() + self.num_classes = num_classes + self.num_features = 1536 + assert output_stride == 32 + conv_block = partial( + ConvNormAct, + padding=0, + norm_layer=norm_layer, + act_layer=act_layer, + norm_kwargs=dict(eps=norm_eps), + act_kwargs=dict(inplace=True), + ) + + self.conv2d_1a = conv_block(in_chans, 32, kernel_size=3, stride=2) + self.conv2d_2a = conv_block(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = conv_block(32, 64, kernel_size=3, stride=1, padding=1) + self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] + + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = conv_block(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = conv_block(80, 192, kernel_size=3, stride=1) + self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] + + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b(conv_block=conv_block) + self.repeat = nn.Sequential(*[Block35(scale=0.17, conv_block=conv_block) for _ in range(10)]) + self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] + + self.mixed_6a = Mixed_6a(conv_block=conv_block) + self.repeat_1 = nn.Sequential(*[Block17(scale=0.10, conv_block=conv_block) for _ in range(20)]) + self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] + + self.mixed_7a = Mixed_7a(conv_block=conv_block) + self.repeat_2 = nn.Sequential(*[Block8(scale=0.20, conv_block=conv_block) for _ in range(9)]) + + self.block8 = 
Block8(no_relu=True, conv_block=conv_block) + self.conv2d_7b = conv_block(2080, self.num_features, kernel_size=1, stride=1) + self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] + + self.global_pool, self.head_drop, self.classif = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))} + module_map.pop(('classif',)) + + def _matcher(name): + if any([name.startswith(n) for n in ('conv2d_1', 'conv2d_2')]): + return 0 + elif any([name.startswith(n) for n in ('conv2d_3', 'conv2d_4')]): + return 1 + elif any([name.startswith(n) for n in ('block8', 'conv2d_7')]): + return len(module_map) + 1 + else: + for k in module_map.keys(): + if k == tuple(name.split('.')[:len(k)]): + return module_map[k] + return float('inf') + return _matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "checkpointing not supported" + + @torch.jit.ignore + def get_classifier(self): + return self.classif + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + x = self.head_drop(x) + return x if pre_logits else self.classif(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): + return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs) + + +default_cfgs = generate_default_cfgs({ + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz + 'inception_resnet_v2.tf_in1k': { + 'hf_hub_id': 'timm/', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + }, + # As per https://arxiv.org/abs/1705.07204 and + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'inception_resnet_v2.tf_ens_adv_in1k': { + 'hf_hub_id': 'timm/', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + } +}) + + +@register_model +def inception_resnet_v2(pretrained=False, **kwargs) -> InceptionResnetV2: + return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) + + +register_model_deprecations(__name__, { + 'ens_adv_inception_resnet_v2': 'inception_resnet_v2.tf_ens_adv_in1k', +}) \ No newline at end of file diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mobilenetv3.py 
b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..2d197a9d9c73061348d6928e43602d91a40ff81a --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mobilenetv3.py @@ -0,0 +1,845 @@ +""" MobileNet V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2019, Ross Wightman +""" +from functools import partial +from typing import Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite +from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from ._features import FeatureInfo, FeatureHooks +from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['MobileNetV3', 'MobileNetV3Features'] + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific + 'efficient head', where global pooling is done before the head convolution without a final batch-norm + layer before the classifier. + + Paper: `Searching for MobileNetV3` - https://arxiv.org/abs/1905.02244 + + Other architectures utilizing MobileNet-V3 efficient head that are supported by this impl include: + * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class) + * FBNet-V3 - https://arxiv.org/abs/2006.02049 + * LCNet - https://arxiv.org/abs/2109.15099 + """ + + def __init__( + self, + block_args: BlockArgs, + num_classes: int = 1000, + in_chans: int = 3, + stem_size: int = 16, + fix_stem: bool = False, + num_features: int = 1280, + head_bias: bool = True, + pad_type: PadType = '', + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + se_from_exp: bool = True, + round_chs_fn: Callable = round_channels, + drop_rate: float = 0., + drop_path_rate: float = 0., + global_pool: str = 'avg', + ): + """ + Args: + block_args: Arguments for blocks of the network. + num_classes: Number of classes for classification head. + in_chans: Number of input image channels. + stem_size: Number of output channels of the initial stem convolution. + fix_stem: If True, don't scale stem by round_chs_fn. + num_features: Number of output channels of the conv head layer. + head_bias: If True, add a learnable bias to the conv head layer. + pad_type: Type of padding to use for convolution layers. + act_layer: Type of activation layer. + norm_layer: Type of normalization layer. + se_layer: Type of Squeeze-and-Excite layer. + se_from_exp: If True, calculate SE channel reduction from expanded mid channels. + round_chs_fn: Callable to round number of filters based on depth multiplier. + drop_rate: Dropout rate.
+ drop_path_rate: Stochastic depth rate. + global_pool: Type of pooling to use for global pooling features of the FC head. + """ + super(MobileNetV3, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + se_layer = se_layer or SqueezeExcite + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self.grad_checkpointing = False + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_act_layer(stem_size, inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=32, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + se_from_exp=se_from_exp, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + ) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + num_pooled_chs = head_chs * self.global_pool.feat_mult() + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1] + layers.extend(self.blocks) + layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False): + return dict( + stem=r'^conv_stem|bn1', + blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)' + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + self.num_classes = num_classes + # cannot meaningfully change pooling of efficient head after creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_stem(x) + x = self.bn1(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x, flatten=True) + else: + x = self.blocks(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = self.flatten(x) + if pre_logits: + return x + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +class MobileNetV3Features(nn.Module): + """ MobileNetV3 Feature Extractor + + A work-in-progress feature extraction module for MobileNet-V3 to use as 
a backbone for segmentation + and object detection models. + """ + + def __init__( + self, + block_args: BlockArgs, + out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + feature_location: str = 'bottleneck', + in_chans: int = 3, + stem_size: int = 16, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: PadType = '', + round_chs_fn: Callable = round_channels, + se_from_exp: bool = True, + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + drop_rate: float = 0., + drop_path_rate: float = 0., + ): + """ + Args: + block_args: Arguments for blocks of the network. + out_indices: Output from stages at indices. + feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'] + in_chans: Number of input image channels. + stem_size: Number of output channels of the initial stem convolution. + fix_stem: If True, don't scale stem by round_chs_fn. + output_stride: Output stride of the network. + pad_type: Type of padding to use for convolution layers. + round_chs_fn: Callable to round number of filters based on depth multiplier. + se_from_exp: If True, calculate SE channel reduction from expanded mid channels. + act_layer: Type of activation layer. + norm_layer: Type of normalization layer. + se_layer: Type of Squeeze-and-Excite layer. + drop_rate: Dropout rate. + drop_path_rate: Stochastic depth rate. + """ + super(MobileNetV3Features, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.drop_rate = drop_rate + self.grad_checkpointing = False + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + se_from_exp=se_from_exp, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + feature_location=feature_location, + ) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + self.grad_checkpointing = enable + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(b, x) + else: + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3: + features_mode = '' + model_cls = MobileNetV3 + kwargs_filter = 
None + if kwargs.pop('features_only', False): + if 'feature_cfg' in kwargs: + features_mode = 'cfg' + else: + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + features_mode = 'cls' + + model = build_model_with_cfg( + model_cls, + variant, + pretrained, + features_only=features_mode == 'cfg', + pretrained_strict=features_mode != 'cls', + kwargs_filter=kwargs_filter, + **kwargs, + ) + if features_mode == 'cls': + model.default_cfg = pretrained_cfg_for_features(model.default_cfg) + return model + + +def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'), + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
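The features_only plumbing above is what exposes MobileNetV3Features through timm's public factory: when features_only=True is popped from kwargs, head-specific arguments are filtered out and the multi-output class is built instead of the classifier. A minimal usage sketch (assumes only the public timm and torch APIs; not part of the diffed file):

```python
import torch
import timm

# features_only=True routes create_model through the MobileNetV3Features path
# above (features_mode == 'cls'), so forward() returns one tensor per out_index.
model = timm.create_model('mobilenetv3_large_100', pretrained=False, features_only=True)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))

# feature_info records the channel count and reduction factor of each output
for f, info in zip(features, model.feature_info.get_dicts()):
    print(info['module'], 'reduction', info['reduction'], 'shape', tuple(f.shape))
```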
+ """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + fix_stem=channel_multiplier < 0.75, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): + """ FBNetV3 + Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` + - https://arxiv.org/abs/2006.02049 + FIXME untested, this is a preliminary impl of some FBNet-V3 variants. 
+ """ + vl = variant.split('_')[-1] + if vl in ('a', 'b'): + stem_size = 16 + arch_def = [ + ['ds_r2_k3_s1_e1_c16'], + ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'], + ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'], + ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'], + ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'], + ['cn_r1_k1_s1_c1344'], + ] + elif vl == 'd': + stem_size = 24 + arch_def = [ + ['ds_r2_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'], + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'], + ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'], + ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'], + ['cn_r1_k1_s1_c1440'], + ] + elif vl == 'g': + stem_size = 32 + arch_def = [ + ['ds_r3_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'], + ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'], + ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'], + ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'], + ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'], + ['cn_r1_k1_s1_c1728'], + ] + else: + raise NotImplemented + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn) + act_layer = resolve_act_layer(kwargs, 'hard_swish') + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1984, + head_bias=False, + stem_size=stem_size, + round_chs_fn=round_chs_fn, + se_from_exp=False, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): + """ LCNet + Essentially a MobileNet-V3 crossed with a MobileNet-V1 + + Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['dsa_r1_k3_s1_c32'], + # stage 1, 112x112 in + ['dsa_r2_k3_s2_c64'], + # stage 2, 56x56 in + ['dsa_r2_k3_s2_c128'], + # stage 3, 28x28 in + ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'], + # stage 4, 14x14in + ['dsa_r4_k5_s1_c256'], + # stage 5, 14x14in + ['dsa_r2_k5_s2_c512_se0.25'], + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU), + num_features=1280, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): + """ LCNet + Essentially a MobileNet-V3 crossed with a MobileNet-V1 + + Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + # stage 0, 112x112 in + ['dsa_r1_k3_s1_c32'], + # stage 1, 112x112 in + ['dsa_r2_k3_s2_c64'], + # stage 2, 56x56 in + ['dsa_r2_k3_s2_c128'], + # stage 3, 28x28 in + ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'], + # stage 4, 14x14in + ['dsa_r4_k5_s1_c256'], + # stage 5, 14x14in + ['dsa_r2_k5_s2_c512_se0.25'], + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU), + num_features=1280, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _cfg(url: str = '', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'mobilenetv3_large_075.untrained': _cfg(url=''), + 'mobilenetv3_large_100.ra_in1k': _cfg( + interpolation='bicubic', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + hf_hub_id='timm/'), + 'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg( + interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), + origin_url='https://github.com/Alibaba-MIIL/ImageNet21K', + paper_ids='arXiv:2104.10972v4', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth', + hf_hub_id='timm/'), + 'mobilenetv3_large_100.miil_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth', + hf_hub_id='timm/', + origin_url='https://github.com/Alibaba-MIIL/ImageNet21K', + paper_ids='arXiv:2104.10972v4', + interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221), + + 'mobilenetv3_small_050.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv3_small_075.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv3_small_100.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + + 'mobilenetv3_rw.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + + 'tf_mobilenetv3_large_075.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, 
std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'fbnetv3_b.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth', + hf_hub_id='timm/', + test_input_size=(3, 256, 256), crop_pct=0.95), + 'fbnetv3_d.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth', + hf_hub_id='timm/', + test_input_size=(3, 256, 256), crop_pct=0.95), + 'fbnetv3_g.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)), + + "lcnet_035.untrained": _cfg(), + "lcnet_050.ra2_in1k": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + "lcnet_075.ra2_in1k": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + "lcnet_100.ra2_in1k": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + "lcnet_150.untrained": _cfg(), +}) + + +@register_model +def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model
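The _cfg fields gathered into default_cfgs (input_size, interpolation, mean/std, crop_pct, test_input_size) are what timm's data helpers read to build preprocessing that matches a given set of weights. A short sketch with the public data API (illustrative):

```python
import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('tf_mobilenetv3_large_100', pretrained=False)

# resolve_data_config pulls input_size / interpolation / mean / std / crop_pct
# out of the model's pretrained cfg, i.e. the _cfg entries defined above
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config['input_size'], config['interpolation'], config['mean'], config['std'])
```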
+ + +@register_model +def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT) + kwargs.setdefault('pad_type', 'same') + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ FBNetV3-B """ + model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ FBNetV3-D """ + model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ FBNetV3-G """ + model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ PP-LCNet 0.35""" + model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ PP-LCNet 0.5""" + model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ PP-LCNet 0.75""" + model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ PP-LCNet 1.0""" +
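Each @register_model entrypoint above is recorded in timm's model registry under its function name, which is what makes the variants discoverable and constructible by string. For example (public registry API, illustrative):

```python
import timm

print(timm.list_models('tf_mobilenetv3*'))           # names registered above
print(timm.list_models('lcnet*', pretrained=True))   # only names with weights

m = timm.create_model('tf_mobilenetv3_small_100')    # bn_eps / pad_type defaults applied
print(f'{sum(p.numel() for p in m.parameters()) / 1e6:.2f}M params')
```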
model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ PP-LCNet 1.5""" + model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs) + return model + + +register_model_deprecations(__name__, { + 'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k', + 'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k', +}) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/nfnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/nfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..725b177c25b5f58e20247fbab16f03f3a3d6c1cc --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/nfnet.py @@ -0,0 +1,1030 @@ +""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models + +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + +Paper: `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Status: +* These models are a work in progress, experiments ongoing. +* Pretrained weights for two models so far, more to come. +* Model details updated to closer match official JAX code now that it's released +* NF-ResNet, NF-RegNet-B, and NFNet-F models supported + +Hacked together by / copyright Ross Wightman, 2021. +""" +from collections import OrderedDict +from dataclasses import dataclass, replace +from functools import partial +from typing import Callable, Tuple, Optional + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame, \ + get_act_layer, get_act_fn, get_attn, make_divisible +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model + +__all__ = ['NormFreeNet', 'NfCfg'] # model_registry will add each entrypoint fn to this + + +@dataclass +class NfCfg: + depths: Tuple[int, int, int, int] + channels: Tuple[int, int, int, int] + alpha: float = 0.2 + stem_type: str = '3x3' + stem_chs: Optional[int] = None + group_size: Optional[int] = None + attn_layer: Optional[str] = None + attn_kwargs: dict = None + attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used + width_factor: float = 1.0 + bottle_ratio: float = 0.5 + num_features: int = 0 # num out_channels for final conv, no final_conv if 0 + ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal + reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle + extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models + gamma_in_act: bool = False + same_padding: bool = False + std_conv_eps: float = 1e-5 + skipinit: bool = False # disabled by default, non-trivial performance impact + zero_init_fc: bool = False + act_layer: str = 'silu' + + +class GammaAct(nn.Module): + def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False): + super().__init__() + self.act_fn = get_act_fn(act_type) + self.gamma = 
gamma + self.inplace = inplace + + def forward(self, x): + return self.act_fn(x, inplace=self.inplace).mul_(self.gamma) + + +def act_with_gamma(act_type, gamma: float = 1.): + def _create(inplace=False): + return GammaAct(act_type, gamma=gamma, inplace=inplace) + return _create + + +class DownsampleAvg(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + conv_layer: Callable = ScaledStdConv2d, + ): + """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + + def forward(self, x): + return self.conv(self.pool(x)) + + +@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 +class NormFreeBlock(nn.Module): + """Normalization-Free pre-activation block. + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + alpha: float = 1.0, + beta: float = 1.0, + bottle_ratio: float = 0.25, + group_size: Optional[int] = None, + ch_div: int = 1, + reg: bool = True, + extra_conv: bool = False, + skipinit: bool = False, + attn_layer: Optional[Callable] = None, + attn_gain: float = 2.0, + act_layer: Optional[Callable] = None, + conv_layer: Callable = ScaledStdConv2d, + drop_path_rate: float = 0., + ): + super().__init__() + first_dilation = first_dilation or dilation + out_chs = out_chs or in_chs + # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet + mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div) + groups = 1 if not group_size else mid_chs // group_size + if group_size and group_size % ch_div == 0: + mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error + self.alpha = alpha + self.beta = beta + self.attn_gain = attn_gain + + if in_chs != out_chs or stride != 1 or dilation != first_dilation: + self.downsample = DownsampleAvg( + in_chs, + out_chs, + stride=stride, + dilation=dilation, + first_dilation=first_dilation, + conv_layer=conv_layer, + ) + else: + self.downsample = None + + self.act1 = act_layer() + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.act2 = act_layer(inplace=True) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + if extra_conv: + self.act2b = act_layer(inplace=True) + self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) + else: + self.act2b = None + self.conv2b = None + if reg and attn_layer is not None: + self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3 + else: + self.attn = None + self.act3 = act_layer() + self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)
+ if not reg and attn_layer is not None: + self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3 + else: + self.attn_last = None + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None + + def forward(self, x): + out = self.act1(x) * self.beta + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(out) + + # residual branch + out = self.conv1(out) + out = self.conv2(self.act2(out)) + if self.conv2b is not None: + out = self.conv2b(self.act2b(out)) + if self.attn is not None: + out = self.attn_gain * self.attn(out) + out = self.conv3(self.act3(out)) + if self.attn_last is not None: + out = self.attn_gain * self.attn_last(out) + out = self.drop_path(out) + + if self.skipinit_gain is not None: + out.mul_(self.skipinit_gain) + out = out * self.alpha + shortcut + return out + + +def create_stem( + in_chs: int, + out_chs: int, + stem_type: str = '', + conv_layer: Optional[Callable] = None, + act_layer: Optional[Callable] = None, + preact_feature: bool = True, +): + stem_stride = 2 + stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv') + stem = OrderedDict() + assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + if 'deep' in stem_type: + if 'quad' in stem_type: + # 4 deep conv stack as in NFNet-F models + assert not 'pool' in stem_type + stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs) + strides = (2, 1, 1, 2) + stem_stride = 4 + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3') + else: + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets + strides = (2, 1, 1) + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2') + last_idx = len(stem_chs) - 1 + for i, (c, s) in enumerate(zip(stem_chs, strides)): + stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + if i != last_idx: + stem[f'act{i + 2}'] = act_layer(inplace=True) + in_chs = c + elif '3x3' in stem_type: + # 3x3 stem conv as in RegNet + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) + else: + # 7x7 stem conv as in ResNet + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + + if 'pool' in stem_type: + stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1) + stem_stride = 4 + + return nn.Sequential(stem), stem_stride, stem_feature + + +# from https://github.com/deepmind/deepmind-research/tree/master/nfnets +_nonlin_gamma = dict( + identity=1.0, + celu=1.270926833152771, + elu=1.2716004848480225, + gelu=1.7015043497085571, + leaky_relu=1.70590341091156, + log_sigmoid=1.9193484783172607, + log_softmax=1.0002083778381348, + relu=1.7139588594436646, + relu6=1.7131484746932983, + selu=1.0008515119552612, + sigmoid=4.803835391998291, + silu=1.7881293296813965, + softsign=2.338853120803833, + softplus=1.9203323125839233, + tanh=1.5939117670059204, +) + + +class NormFreeNet(nn.Module): + """ Normalization-Free Network + + As described in : + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + and + `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171 + + This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code 
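The _nonlin_gamma constants above are variance-correction gains: for x ~ N(0, 1), gamma is chosen so that gamma * f(x) has roughly unit standard deviation, which is what keeps the signal variance stable through an unnormalized network. They can be approximated by Monte Carlo estimation (a sketch, not part of the diffed file):

```python
import torch

torch.manual_seed(0)
x = torch.randn(16_000_000)  # large N(0, 1) sample
for name, fn in [('relu', torch.nn.functional.relu),
                 ('gelu', torch.nn.functional.gelu),
                 ('silu', torch.nn.functional.silu)]:
    gamma = 1.0 / fn(x).std().item()
    # compare with _nonlin_gamma: relu ~1.714, gelu ~1.7015, silu ~1.7881
    print(f'{name}: gamma ~ {gamma:.4f}')
```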
snippets and + the (preact) ResNet models described earlier in the paper. + + There are a few differences: + * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy), + this changes channel dim and param counts slightly from the paper models + * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance + impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl. + * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but + apply it in each activation. This is slightly slower, numerically different, but matches official impl. + * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput + for what it is/does. Approx 8-10% throughput loss. + """ + def __init__( + self, + cfg: NfCfg, + num_classes: int = 1000, + in_chans: int = 3, + global_pool: str = 'avg', + output_stride: int = 32, + drop_rate: float = 0., + drop_path_rate: float = 0., + **kwargs, + ): + """ + Args: + cfg: Model architecture configuration. + num_classes: Number of classifier classes. + in_chans: Number of input channels. + global_pool: Global pooling type. + output_stride: Output stride of network, one of (8, 16, 32). + drop_rate: Dropout rate. + drop_path_rate: Stochastic depth drop-path rate. + **kwargs: Extra kwargs overlayed onto cfg. + """ + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.grad_checkpointing = False + + cfg = replace(cfg, **kwargs) + assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})." + conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d + if cfg.gamma_in_act: + act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer]) + conv_layer = partial(conv_layer, eps=cfg.std_conv_eps) + else: + act_layer = get_act_layer(cfg.act_layer) + conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps) + attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None + + stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) + self.stem, stem_stride, stem_feat = create_stem( + in_chans, + stem_chs, + cfg.stem_type, + conv_layer=conv_layer, + act_layer=act_layer, + ) + + self.feature_info = [stem_feat] + drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + prev_chs = stem_chs + net_stride = stem_stride + dilation = 1 + expected_var = 1.0 + stages = [] + for stage_idx, stage_depth in enumerate(cfg.depths): + stride = 1 if stage_idx == 0 and stem_stride > 2 else 2 + if net_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + blocks = [] + for block_idx in range(cfg.depths[stage_idx]): + first_block = block_idx == 0 and stage_idx == 0 + out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div) + blocks += [NormFreeBlock( + in_chs=prev_chs, out_chs=out_chs, + alpha=cfg.alpha, + beta=1. / expected_var ** 0.5, + stride=stride if block_idx == 0 else 1, + dilation=dilation, + first_dilation=first_dilation, + group_size=cfg.group_size, + bottle_ratio=1. 
if cfg.reg and first_block else cfg.bottle_ratio, + ch_div=cfg.ch_div, + reg=cfg.reg, + extra_conv=cfg.extra_conv, + skipinit=cfg.skipinit, + attn_layer=attn_layer, + attn_gain=cfg.attn_gain, + act_layer=act_layer, + conv_layer=conv_layer, + drop_path_rate=drop_path_rates[stage_idx][block_idx], + )] + if block_idx == 0: + expected_var = 1. # expected var is reset after first block of each stage + expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance + first_dilation = dilation + prev_chs = out_chs + self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')] + stages += [nn.Sequential(*blocks)] + self.stages = nn.Sequential(*stages) + + if cfg.num_features: + # The paper NFRegNet models have an EfficientNet-like final head convolution. + self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) + self.final_conv = conv_layer(prev_chs, self.num_features, 1) + self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv') + else: + self.num_features = prev_chs + self.final_conv = nn.Identity() + self.final_act = act_layer(inplace=cfg.num_features > 0) + + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + + for n, m in self.named_modules(): + if 'fc' in n and isinstance(m, nn.Linear): + if cfg.zero_init_fc: + nn.init.zeros_(m.weight) + else: + nn.init.normal_(m.weight, 0., .01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear') + if m.bias is not None: + nn.init.zeros_(m.bias) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=[ + (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None), + (r'^final_conv', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head.reset(num_classes, global_pool) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + x = self.final_conv(x) + x = self.final_act(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _nfres_cfg( + depths, + channels=(256, 512, 1024, 2048), + group_size=None, + act_layer='relu', + attn_layer=None, + attn_kwargs=None, +): + attn_kwargs = attn_kwargs or {} + cfg = NfCfg( + depths=depths, + channels=channels, + stem_type='7x7_pool', + stem_chs=64, + bottle_ratio=0.25, + group_size=group_size, + act_layer=act_layer, + attn_layer=attn_layer, + attn_kwargs=attn_kwargs, + ) + return cfg + + +def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): + num_features = 1280 * channels[-1] // 440 + attn_kwargs = dict(rd_ratio=0.5) + cfg = NfCfg( + depths=depths, + channels=channels, + stem_type='3x3', + group_size=8, + width_factor=0.75, + bottle_ratio=2.25, + num_features=num_features, + reg=True, + attn_layer='se', + attn_kwargs=attn_kwargs, + ) + return cfg + + +def _nfnet_cfg( + depths, + channels=(256, 512, 1536, 1536), + 
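The expected_var bookkeeping in the stage loop above encodes the paper's signal-propagation rule: each block downscales its input by beta = 1 / sqrt(expected_var), every residual block adds alpha^2 of variance, and the counter resets to 1 after the first (transition) block of each stage. A standalone trace of the betas this schedule produces, mirroring the loop above with NFNet-F0 depths (illustrative):

```python
alpha = 0.2
depths = (1, 2, 6, 3)  # NFNet-F0 stage depths

expected_var = 1.0
for stage_idx, depth in enumerate(depths):
    for block_idx in range(depth):
        beta = 1.0 / expected_var ** 0.5
        print(f'stage {stage_idx} block {block_idx}: beta = {beta:.3f}')
        if block_idx == 0:
            expected_var = 1.0       # reset after each stage's transition block
        expected_var += alpha ** 2   # every block adds alpha^2 of variance
```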
group_size=128, + bottle_ratio=0.5, + feat_mult=2., + act_layer='gelu', + attn_layer='se', + attn_kwargs=None, +): + num_features = int(channels[-1] * feat_mult) + attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) + cfg = NfCfg( + depths=depths, + channels=channels, + stem_type='deep_quad', + stem_chs=128, + group_size=group_size, + bottle_ratio=bottle_ratio, + extra_conv=True, + num_features=num_features, + act_layer=act_layer, + attn_layer=attn_layer, + attn_kwargs=attn_kwargs, + ) + return cfg + + +def _dm_nfnet_cfg( + depths, + channels=(256, 512, 1536, 1536), + act_layer='gelu', + skipinit=True, +): + cfg = NfCfg( + depths=depths, + channels=channels, + stem_type='deep_quad', + stem_chs=128, + group_size=128, + bottle_ratio=0.5, + extra_conv=True, + gamma_in_act=True, + same_padding=True, + skipinit=skipinit, + num_features=int(channels[-1] * 2.0), + act_layer=act_layer, + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.5), + ) + return cfg + + +model_cfgs = dict( + # NFNet-F models w/ GELU compatible with DeepMind weights + dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), + dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)), + dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)), + dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)), + dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)), + dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)), + dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)), + + # NFNet-F models w/ GELU + nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), + nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), + nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), + nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)), + nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)), + nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)), + nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), + nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), + + # Experimental 'light' versions of NFNet-F that are a little leaner, w/ SiLU act + nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'), + eca_nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l1=_nfnet_cfg( + depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l2=_nfnet_cfg( + depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l3=_nfnet_cfg( + depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + + # EffNet influenced RegNet defs. + # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
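A pattern worth noting in model_cfgs just above: the NFNet-F depths scale linearly with the variant index, F_n = (n + 1) * (1, 2, 6, 3). A quick check (illustrative):

```python
base = (1, 2, 6, 3)
for n in range(8):  # nfnet_f0 .. nfnet_f7
    print(f'nfnet_f{n}:', tuple(d * (n + 1) for d in base))
# f0 -> (1, 2, 6, 3), f1 -> (2, 4, 12, 6), ..., f7 -> (8, 16, 48, 24)
```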
+ nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), + nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)), + nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)), + nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)), + nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)), + nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)), + + # ResNet (preact, D style deep stem/avg down) defs + nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)), + nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), + nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), + + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + + nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()), +) + + +def _create_normfreenet(variant, pretrained=False, **kwargs): + model_cfg = model_cfgs[variant] + feature_cfg = dict(flatten_sequential=True) + return build_model_with_cfg( + NormFreeNet, + variant, + pretrained, + model_cfg=model_cfg, + feature_cfg=feature_cfg, + **kwargs, + ) + + +def _dcfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'dm_nfnet_f0.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth', + pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9, crop_mode='squash'), + 'dm_nfnet_f1.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91, crop_mode='squash'), + 'dm_nfnet_f2.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92, crop_mode='squash'), + 'dm_nfnet_f3.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94, crop_mode='squash'), + 'dm_nfnet_f4.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth', + pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951, crop_mode='squash'), + 'dm_nfnet_f5.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth', + pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954, crop_mode='squash'), + 
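A consistency check on the dm_nfnet cfgs above: given the networks' total stride of 32, pool_size is always input_size // 32 (the same relation holds for the nfnet_f* cfgs that follow). Illustrative:

```python
cfgs = {  # name: (train input_size, pool_size), copied from the entries above
    'dm_nfnet_f0': (192, 6), 'dm_nfnet_f1': (224, 7), 'dm_nfnet_f2': (256, 8),
    'dm_nfnet_f3': (320, 10), 'dm_nfnet_f4': (384, 12), 'dm_nfnet_f5': (416, 13),
}
for name, (inp, pool) in cfgs.items():
    assert inp // 32 == pool, name
print('pool_size == input_size // 32 for all entries')
```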
'dm_nfnet_f6.dm_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth', + pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956, crop_mode='squash'), + + 'nfnet_f0': _dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + 'nfnet_f1': _dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), + 'nfnet_f2': _dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), + 'nfnet_f3': _dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), + 'nfnet_f4': _dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), + 'nfnet_f5': _dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), + 'nfnet_f6': _dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), + 'nfnet_f7': _dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), + + 'nfnet_l0.ra2_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'eca_nfnet_l0.ra2_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'eca_nfnet_l1.ra2_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), test_crop_pct=1.0), + 'eca_nfnet_l2.ra3_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), test_crop_pct=1.0), + 'eca_nfnet_l3': _dcfg( + url='', + pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), test_crop_pct=1.0), + + 'nf_regnet_b0': _dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), + 'nf_regnet_b1.ra2_in1k': _dcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec + 'nf_regnet_b2': _dcfg( + url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'), + 'nf_regnet_b3': _dcfg( + url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'), + 'nf_regnet_b4': _dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'), + 'nf_regnet_b5': _dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'), + + 'nf_resnet26': _dcfg(url='', first_conv='stem.conv'), + 'nf_resnet50.ra2_in1k': _dcfg( + hf_hub_id='timm/', + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'), + 'nf_resnet101': _dcfg(url='', first_conv='stem.conv'), + + 'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'), + 'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'), + 'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'), + + 'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'), + 'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'), + 'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'), +}) + + +@register_model +def dm_nfnet_f0(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F0 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f1(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F1 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f2(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F2 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f3(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F3 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f4(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F4 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f5(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F5 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f6(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F6 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f0(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F0 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F1 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F2 + 
`High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F3 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f4(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F4 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f5(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F5 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F6 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f7(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-F7 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_l0(pretrained=False, **kwargs) -> NormFreeNet: + """ NFNet-L0b w/ SiLU + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + """ + return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l0(pretrained=False, **kwargs) -> NormFreeNet: + """ ECA-NFNet-L0 w/ SiLU + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l1(pretrained=False, **kwargs) -> NormFreeNet: + """ ECA-NFNet-L1 w/ SiLU + My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l2(pretrained=False, **kwargs) -> NormFreeNet: + """ ECA-NFNet-L2 w/ SiLU + My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l3(pretrained=False, **kwargs) -> NormFreeNet: + """ ECA-NFNet-L3 w/ SiLU + My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b0(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free RegNet-B0 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b1(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free 
RegNet-B1 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b2(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free RegNet-B2 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b3(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free RegNet-B3 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b4(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free RegNet-B4 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b5(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free RegNet-B5 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) + + +@register_model +def nf_resnet26(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free ResNet-26 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs) + + +@register_model +def nf_resnet50(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free ResNet-50 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs) + + +@register_model +def nf_resnet101(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free ResNet-101 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs) + + +@register_model +def nf_seresnet26(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free SE-ResNet26 + """ + return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs) + + +@register_model +def nf_seresnet50(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free SE-ResNet50 + """ + return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs) + + +@register_model +def nf_seresnet101(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free SE-ResNet101 + """ + return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs) + + +@register_model +def nf_ecaresnet26(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free ECA-ResNet26 + """ + return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs) + + +@register_model +def nf_ecaresnet50(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free ECA-ResNet50 + """ + return _create_normfreenet('nf_ecaresnet50', 
pretrained=pretrained, **kwargs) + + +@register_model +def nf_ecaresnet101(pretrained=False, **kwargs) -> NormFreeNet: + """ Normalization-Free ECA-ResNet101 + """ + return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/regnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4ece9f4c017aada4a41b92f48aae312508d4fb6c --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/regnet.py @@ -0,0 +1,1122 @@ +"""RegNet X, Y, Z, and more + +Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678 +Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + +Paper: `Fast and Accurate Model Scaling` - https://arxiv.org/abs/2103.06877 +Original Impl: None + +Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here) +and cleaned up with more descriptive variable names. + +Weights from original pycls impl have been modified: +* first layer from BGR -> RGB as most PyTorch models are +* removed training specific dict entries from checkpoints and keep model state_dict only +* remap names to match the ones here + +Supports weight loading from torchvision and classy-vision (incl VISSL SEER) + +A number of custom timm model definitions additions including: +* stochastic depth, gradient checkpointing, layer-decay, configurable dilation +* a pre-activation 'V' variant +* only known RegNet-Z model definitions with pretrained weights + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +from dataclasses import dataclass, replace +from functools import partial +from typing import Optional, Union, Callable + +import numpy as np +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, named_apply +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this + + +@dataclass +class RegNetCfg: + depth: int = 21 + w0: int = 80 + wa: float = 42.63 + wm: float = 2.66 + group_size: int = 24 + bottle_ratio: float = 1. + se_ratio: float = 0. + group_min_ratio: float = 0. 
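+    # These fields are facts of the file itself: w0 (initial width), wa (width slope) and wm
+    # (width multiplier) parameterize the RegNet width ramp u_j = w0 + wa * j, quantized to
+    # the nearest power of wm (see generate_regnet below); group_min_ratio > 0 selects the
+    # torchvision rounding scheme in adjust_widths_groups_comp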
+ stem_width: int = 32 + downsample: Optional[str] = 'conv1x1' + linear_out: bool = False + preact: bool = False + num_features: int = 0 + act_layer: Union[str, Callable] = 'relu' + norm_layer: Union[str, Callable] = 'batchnorm' + + +def quantize_float(f, q): + """Converts a float to the closest non-zero int divisible by q.""" + return int(round(f / q) * q) + + +def adjust_widths_groups_comp(widths, bottle_ratios, groups, min_ratio=0.): + """Adjusts the compatibility of widths and groups.""" + bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)] + if min_ratio: + # torchvision uses a different rounding scheme for ensuring bottleneck widths divisible by group widths + bottleneck_widths = [make_divisible(w_bot, g, min_ratio) for w_bot, g in zip(bottleneck_widths, groups)] + else: + bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)] + widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)] + return widths, groups + + +def generate_regnet(width_slope, width_initial, width_mult, depth, group_size, quant=8): + """Generates per block widths from RegNet parameters.""" + assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % quant == 0 + # TODO dWr scaling? + # depth = int(depth * (scale ** 0.1)) + # width_scale = scale ** 0.4 # dWr scale, exp 0.8 / 2, applied to both group and layer widths + widths_cont = np.arange(depth) * width_slope + width_initial + width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult)) + widths = np.round(np.divide(width_initial * np.power(width_mult, width_exps), quant)) * quant + num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1 + groups = np.array([group_size for _ in range(num_stages)]) + return widths.astype(int).tolist(), num_stages, groups.astype(int).tolist() + + +def downsample_conv( + in_chs, + out_chs, + kernel_size=1, + stride=1, + dilation=1, + norm_layer=None, + preact=False, +): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + dilation = dilation if kernel_size > 1 else 1 + if preact: + return create_conv2d( + in_chs, + out_chs, + kernel_size, + stride=stride, + dilation=dilation, + ) + else: + return ConvNormAct( + in_chs, + out_chs, + kernel_size, + stride=stride, + dilation=dilation, + norm_layer=norm_layer, + apply_act=False, + ) + + +def downsample_avg( + in_chs, + out_chs, + kernel_size=1, + stride=1, + dilation=1, + norm_layer=None, + preact=False, +): + """ AvgPool Downsampling as in 'D' ResNet variants. 
This is not in RegNet space but I might experiment.""" + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + pool = nn.Identity() + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + if preact: + conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False) + return nn.Sequential(*[pool, conv]) + + +def create_shortcut( + downsample_type, + in_chs, + out_chs, + kernel_size, + stride, + dilation=(1, 1), + norm_layer=None, + preact=False, +): + assert downsample_type in ('avg', 'conv1x1', '', None) + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact) + if not downsample_type: + return None # no shortcut, no downsample + elif downsample_type == 'avg': + return downsample_avg(in_chs, out_chs, **dargs) + else: + return downsample_conv(in_chs, out_chs, kernel_size=kernel_size, **dargs) + else: + return nn.Identity() # identity shortcut (no downsample) + + +class Bottleneck(nn.Module): + """ RegNet Bottleneck + + This is almost exactly the same as a ResNet Bottleneck. The main difference is the SE block is moved from + after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. + """ + + def __init__( + self, + in_chs, + out_chs, + stride=1, + dilation=(1, 1), + bottle_ratio=1, + group_size=1, + se_ratio=0.25, + downsample='conv1x1', + linear_out=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_block=None, + drop_path_rate=0., + ): + super(Bottleneck, self).__init__() + act_layer = get_act_layer(act_layer) + bottleneck_chs = int(round(out_chs * bottle_ratio)) + groups = bottleneck_chs // group_size + + cargs = dict(act_layer=act_layer, norm_layer=norm_layer) + self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) + self.conv2 = ConvNormAct( + bottleneck_chs, + bottleneck_chs, + kernel_size=3, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + **cargs, + ) + if se_ratio: + se_channels = int(round(in_chs * se_ratio)) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) + else: + self.se = nn.Identity() + self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs) + self.act3 = nn.Identity() if linear_out else act_layer() + self.downsample = create_shortcut( + downsample, + in_chs, + out_chs, + kernel_size=1, + stride=stride, + dilation=dilation, + norm_layer=norm_layer, + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def zero_init_last(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + x = self.se(x) + x = self.conv3(x) + if self.downsample is not None: + # NOTE stuck with downsample as the attr name due to weight compatibility + # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity() + x = self.drop_path(x) + self.downsample(shortcut) + x = self.act3(x) + return x + + +class PreBottleneck(nn.Module): + """ RegNet Bottleneck (pre-activation) + + This is almost exactly the same as a ResNet Bottleneck. The main difference is the SE block is moved from + after conv3 to after conv2.
Otherwise, it's just redefining the arguments for groups/bottleneck channels. + """ + + def __init__( + self, + in_chs, + out_chs, + stride=1, + dilation=(1, 1), + bottle_ratio=1, + group_size=1, + se_ratio=0.25, + downsample='conv1x1', + linear_out=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_block=None, + drop_path_rate=0., + ): + super(PreBottleneck, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + bottleneck_chs = int(round(out_chs * bottle_ratio)) + groups = bottleneck_chs // group_size + + self.norm1 = norm_act_layer(in_chs) + self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1) + self.norm2 = norm_act_layer(bottleneck_chs) + self.conv2 = create_conv2d( + bottleneck_chs, + bottleneck_chs, + kernel_size=3, + stride=stride, + dilation=dilation[0], + groups=groups, + ) + if se_ratio: + se_channels = int(round(in_chs * se_ratio)) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) + else: + self.se = nn.Identity() + self.norm3 = norm_act_layer(bottleneck_chs) + self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1) + self.downsample = create_shortcut( + downsample, + in_chs, + out_chs, + kernel_size=1, + stride=stride, + dilation=dilation, + preact=True, + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def zero_init_last(self): + pass + + def forward(self, x): + x = self.norm1(x) + shortcut = x + x = self.conv1(x) + x = self.norm2(x) + x = self.conv2(x) + x = self.se(x) + x = self.norm3(x) + x = self.conv3(x) + if self.downsample is not None: + # NOTE stuck with downsample as the attr name due to weight compatibility + # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity() + x = self.drop_path(x) + self.downsample(shortcut) + return x + + +class RegStage(nn.Module): + """Stage (sequence of blocks w/ the same output shape).""" + + def __init__( + self, + depth, + in_chs, + out_chs, + stride, + dilation, + drop_path_rates=None, + block_fn=Bottleneck, + **block_kwargs, + ): + super(RegStage, self).__init__() + self.grad_checkpointing = False + + first_dilation = 1 if dilation in (1, 2) else 2 + for i in range(depth): + block_stride = stride if i == 0 else 1 + block_in_chs = in_chs if i == 0 else out_chs + block_dilation = (first_dilation, dilation) + dpr = drop_path_rates[i] if drop_path_rates is not None else 0. 
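+            # dpr is this block's stochastic depth (DropPath) rate; the per-stage rate
+            # schedules are precomputed from a single linear ramp in RegNet._get_stage_args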
+ name = "b{}".format(i + 1) + self.add_module( + name, + block_fn( + block_in_chs, + out_chs, + stride=block_stride, + dilation=block_dilation, + drop_path_rate=dpr, + **block_kwargs, + ) + ) + first_dilation = dilation + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.children(), x) + else: + for block in self.children(): + x = block(x) + return x + + +class RegNet(nn.Module): + """RegNet-X, Y, and Z Models + + Paper: https://arxiv.org/abs/2003.13678 + Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + """ + + def __init__( + self, + cfg: RegNetCfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0., + drop_path_rate=0., + zero_init_last=True, + **kwargs, + ): + """ + + Args: + cfg (RegNetCfg): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + global_pool (str): Global pooling type (default: 'avg') + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + cfg = replace(cfg, **kwargs) # update cfg with extra passed kwargs + + # Construct the stem + stem_width = cfg.stem_width + na_args = dict(act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) + if cfg.preact: + self.stem = create_conv2d(in_chans, stem_width, 3, stride=2) + else: + self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args) + self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] + + # Construct the stages + prev_width = stem_width + curr_stride = 2 + per_stage_args, common_args = self._get_stage_args( + cfg, + output_stride=output_stride, + drop_path_rate=drop_path_rate, + ) + assert len(per_stage_args) == 4 + block_fn = PreBottleneck if cfg.preact else Bottleneck + for i, stage_args in enumerate(per_stage_args): + stage_name = "s{}".format(i + 1) + self.add_module( + stage_name, + RegStage( + in_chs=prev_width, + block_fn=block_fn, + **stage_args, + **common_args, + ) + ) + prev_width = stage_args['out_chs'] + curr_stride *= stage_args['stride'] + self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] + + # Construct the head + if cfg.num_features: + self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args) + self.num_features = cfg.num_features + else: + final_act = cfg.linear_out or cfg.preact + self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() + self.num_features = prev_width + self.head = ClassifierHead( + in_features=self.num_features, + num_classes=num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) + + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + + def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.): + # Generate RegNet ws per block + widths, num_stages, stage_gs = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth, cfg.group_size) + + # Convert to per stage format + stage_widths, stage_depths = np.unique(widths, return_counts=True) + stage_br = [cfg.bottle_ratio for _ in 
range(num_stages)] + stage_strides = [] + stage_dilations = [] + net_stride = 2 + dilation = 1 + for _ in range(num_stages): + if net_stride >= output_stride: + dilation *= default_stride + stride = 1 + else: + stride = default_stride + net_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) + stage_dpr = np.split(np.linspace(0, drop_path_rate, sum(stage_depths)), np.cumsum(stage_depths[:-1])) + + # Adjust the compatibility of ws and gws + stage_widths, stage_gs = adjust_widths_groups_comp( + stage_widths, stage_br, stage_gs, min_ratio=cfg.group_min_ratio) + arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates'] + per_stage_args = [ + dict(zip(arg_names, params)) for params in + zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr) + ] + common_args = dict( + downsample=cfg.downsample, + se_ratio=cfg.se_ratio, + linear_out=cfg.linear_out, + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + ) + return per_stage_args, common_args + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^s(\d+)' if coarse else r'^s(\d+)\.b(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in list(self.children())[1:-1]: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.stem(x) + x = self.s1(x) + x = self.s2(x) + x = self.s3(x) + x = self.s4(x) + x = self.final_conv(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name='', zero_init_last=False): + if isinstance(module, nn.Conv2d): + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif zero_init_last and hasattr(module, 'zero_init_last'): + module.zero_init_last() + + +def _filter_fn(state_dict): + state_dict = state_dict.get('model', state_dict) + replaces = [ + ('f.a.0', 'conv1.conv'), + ('f.a.1', 'conv1.bn'), + ('f.b.0', 'conv2.conv'), + ('f.b.1', 'conv2.bn'), + ('f.final_bn', 'conv3.bn'), + ('f.se.excitation.0', 'se.fc1'), + ('f.se.excitation.2', 'se.fc2'), + ('f.se', 'se'), + ('f.c.0', 'conv3.conv'), + ('f.c.1', 'conv3.bn'), + ('f.c', 'conv3.conv'), + ('proj.0', 'downsample.conv'), + ('proj.1', 'downsample.bn'), + ('proj', 'downsample.conv'), + ] + if 'classy_state_dict' in state_dict: + # classy-vision & vissl (SEER) weights + import re + state_dict = state_dict['classy_state_dict']['base_model']['model'] + out = {} + for k, v in state_dict['trunk'].items(): + k = k.replace('_feature_blocks.conv1.stem.0', 'stem.conv') + k = k.replace('_feature_blocks.conv1.stem.1', 'stem.bn') + k = re.sub( + r'^_feature_blocks.res\d.block(\d)-(\d+)', + lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k) + k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k) + for s, r in replaces: + k = k.replace(s, r) + out[k] = v + for 
k, v in state_dict['heads'].items(): + if 'projection_head' in k or 'prototypes' in k: + continue + k = k.replace('0.clf.0', 'head.fc') + out[k] = v + return out + if 'stem.0.weight' in state_dict: + # torchvision weights + import re + out = {} + for k, v in state_dict.items(): + k = k.replace('stem.0', 'stem.conv') + k = k.replace('stem.1', 'stem.bn') + k = re.sub( + r'trunk_output.block(\d)\.block(\d+)\-(\d+)', + lambda x: f's{int(x.group(1))}.b{int(x.group(3)) + 1}', k) + for s, r in replaces: + k = k.replace(s, r) + k = k.replace('fc.', 'head.fc.') + out[k] = v + return out + return state_dict + + +# Model FLOPS = three trailing digits * 10^8 +model_cfgs = dict( + # RegNet-X + regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13), + regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22), + regnetx_004_tv=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22, group_min_ratio=0.9), + regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16), + regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16), + regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18), + regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25), + regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23), + regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17), + regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23), + regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19), + regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22), + regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23), + + # RegNet-Y + regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25), + regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25), + regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25), + regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25), + regnety_008_tv=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25, group_min_ratio=0.9), + regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25), + regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25), + regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25), + regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25), + regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25), + regnety_080_tv=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25, group_min_ratio=0.9), + regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25), + regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25), + regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25), + regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25), + regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25), + regnety_2560=RegNetCfg(w0=640, wa=230.83, wm=2.53, group_size=373, depth=27, se_ratio=0.25), + #regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25), + + # Experimental + regnety_040_sgn=RegNetCfg( + w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, + 
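# NOTE: group_size here sets channels per norm group; GroupNormAct derives num_groups from it +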
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), + + # regnetv = 'preact regnet y' + regnetv_040=RegNetCfg( + depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'), + regnetv_064=RegNetCfg( + depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu', + downsample='avg'), + + # RegNet-Z (unverified) + regnetz_005=RegNetCfg( + depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=1024, act_layer='silu', + ), + regnetz_040=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=0, act_layer='silu', + ), + regnetz_040_h=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=1536, act_layer='silu', + ), +) + + +def _create_regnet(variant, pretrained, **kwargs): + return build_model_with_cfg( + RegNet, variant, pretrained, + model_cfg=model_cfgs[variant], + pretrained_filter_fn=_filter_fn, + **kwargs) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'test_input_size': (3, 288, 288), 'crop_pct': 0.95, 'test_crop_pct': 1.0, + 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +def _cfgpyc(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'license': 'mit', 'origin_url': 'https://github.com/facebookresearch/pycls', **kwargs + } + + +def _cfgtv2(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.965, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'license': 'bsd-3-clause', 'origin_url': 'https://github.com/pytorch/vision', **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # timm trained models + 'regnety_032.ra_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'), + 'regnety_040.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth'), + 'regnety_064.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth'), + 'regnety_080.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth'), + 'regnety_120.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'), + 'regnety_160.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'), + 'regnety_160.lion_in12k_ft_in1k': _cfg(hf_hub_id='timm/'), + + # timm in12k pretrain + 'regnety_120.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'regnety_160.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + + # timm custom arch (v and z guess) + trained models + 'regnety_040_sgn.untrained': _cfg(url=''), + 
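# preact ('V') models use a bare conv stem with no attached norm module, hence first_conv='stem' in the cfgs below +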
'regnetv_040.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth', + first_conv='stem'), + 'regnetv_064.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth', + first_conv='stem'), + + 'regnetz_005.untrained': _cfg(url=''), + 'regnetz_040.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)), + 'regnetz_040_h.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)), + + # used in DeiT for distillation (from Facebook DeiT GitHub repository) + 'regnety_160.deit_in1k': _cfg( + hf_hub_id='timm/', url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'), + + 'regnetx_004_tv.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth'), + 'regnetx_008.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth'), + 'regnetx_016.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth'), + 'regnetx_032.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth'), + 'regnetx_080.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth'), + 'regnetx_160.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth'), + 'regnetx_320.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth'), + + 'regnety_004.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth'), + 'regnety_008_tv.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth'), + 'regnety_016.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth'), + 'regnety_032.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth'), + 'regnety_080_tv.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth'), + 'regnety_160.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth'), + 'regnety_320.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth'), + + 'regnety_160.swag_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth', license='cc-by-nc-4.0', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_320.swag_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth', license='cc-by-nc-4.0', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_1280.swag_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + 
url='https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth', license='cc-by-nc-4.0', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'regnety_160.swag_lc_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth', license='cc-by-nc-4.0'), + 'regnety_320.swag_lc_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth', license='cc-by-nc-4.0'), + 'regnety_1280.swag_lc_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth', license='cc-by-nc-4.0'), + + 'regnety_320.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_640.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_1280.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_2560.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet256_finetuned_in1k_model_final_checkpoint_phase38.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'regnety_320.seer': _cfgtv2( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch', + num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + 'regnety_640.seer': _cfgtv2( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch', + num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + 'regnety_1280.seer': _cfgtv2( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch', + num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + # FIXME invalid weight <-> model match, mistake on their end + #'regnety_2560.seer': _cfgtv2( + # url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_cosine_rg256gf_noBNhead_wd1e5_fairstore_bs16_node64_sinkhorn10_proto16k_apex_syncBN64_warmup8k/model_final_checkpoint_phase0.torch', + # num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + + 'regnetx_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_032.pycls_in1k': 
_cfgpyc(hf_hub_id='timm/'), + 'regnetx_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + + 'regnety_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), +}) + + +@register_model +def regnetx_002(pretrained=False, **kwargs) -> RegNet: + """RegNetX-200MF""" + return _create_regnet('regnetx_002', pretrained, **kwargs) + + +@register_model +def regnetx_004(pretrained=False, **kwargs) -> RegNet: + """RegNetX-400MF""" + return _create_regnet('regnetx_004', pretrained, **kwargs) + + +@register_model +def regnetx_004_tv(pretrained=False, **kwargs) -> RegNet: + """RegNetX-400MF w/ torchvision group rounding""" + return _create_regnet('regnetx_004_tv', pretrained, **kwargs) + + +@register_model +def regnetx_006(pretrained=False, **kwargs) -> RegNet: + """RegNetX-600MF""" + return _create_regnet('regnetx_006', pretrained, **kwargs) + + +@register_model +def regnetx_008(pretrained=False, **kwargs) -> RegNet: + """RegNetX-800MF""" + return _create_regnet('regnetx_008', pretrained, **kwargs) + + +@register_model +def regnetx_016(pretrained=False, **kwargs) -> RegNet: + """RegNetX-1.6GF""" + return _create_regnet('regnetx_016', pretrained, **kwargs) + + +@register_model +def regnetx_032(pretrained=False, **kwargs) -> RegNet: + """RegNetX-3.2GF""" + return _create_regnet('regnetx_032', pretrained, **kwargs) + + +@register_model +def regnetx_040(pretrained=False, **kwargs) -> RegNet: + """RegNetX-4.0GF""" + return _create_regnet('regnetx_040', pretrained, **kwargs) + + +@register_model +def regnetx_064(pretrained=False, **kwargs) -> RegNet: + """RegNetX-6.4GF""" + return _create_regnet('regnetx_064', pretrained, **kwargs) + + +@register_model +def regnetx_080(pretrained=False, **kwargs) -> RegNet: + """RegNetX-8.0GF""" + return _create_regnet('regnetx_080', pretrained, **kwargs) + + +@register_model +def regnetx_120(pretrained=False, **kwargs) -> RegNet: + """RegNetX-12GF""" + return _create_regnet('regnetx_120', pretrained, **kwargs) + + +@register_model +def regnetx_160(pretrained=False, **kwargs) -> RegNet: + """RegNetX-16GF""" + return _create_regnet('regnetx_160', pretrained, **kwargs) + + +@register_model +def regnetx_320(pretrained=False, **kwargs) -> RegNet: + """RegNetX-32GF""" + return _create_regnet('regnetx_320', pretrained, **kwargs) + + +@register_model +def regnety_002(pretrained=False, **kwargs) -> RegNet: + """RegNetY-200MF""" + return _create_regnet('regnety_002', pretrained, **kwargs) + + +@register_model +def regnety_004(pretrained=False, **kwargs) -> RegNet: + """RegNetY-400MF""" + return _create_regnet('regnety_004', pretrained, **kwargs) + + +@register_model +def regnety_006(pretrained=False, **kwargs) -> RegNet: + 
"""RegNetY-600MF""" + return _create_regnet('regnety_006', pretrained, **kwargs) + + +@register_model +def regnety_008(pretrained=False, **kwargs) -> RegNet: + """RegNetY-800MF""" + return _create_regnet('regnety_008', pretrained, **kwargs) + + +@register_model +def regnety_008_tv(pretrained=False, **kwargs) -> RegNet: + """RegNetY-800MF w/ torchvision group rounding""" + return _create_regnet('regnety_008_tv', pretrained, **kwargs) + + +@register_model +def regnety_016(pretrained=False, **kwargs) -> RegNet: + """RegNetY-1.6GF""" + return _create_regnet('regnety_016', pretrained, **kwargs) + + +@register_model +def regnety_032(pretrained=False, **kwargs) -> RegNet: + """RegNetY-3.2GF""" + return _create_regnet('regnety_032', pretrained, **kwargs) + + +@register_model +def regnety_040(pretrained=False, **kwargs) -> RegNet: + """RegNetY-4.0GF""" + return _create_regnet('regnety_040', pretrained, **kwargs) + + +@register_model +def regnety_064(pretrained=False, **kwargs) -> RegNet: + """RegNetY-6.4GF""" + return _create_regnet('regnety_064', pretrained, **kwargs) + + +@register_model +def regnety_080(pretrained=False, **kwargs) -> RegNet: + """RegNetY-8.0GF""" + return _create_regnet('regnety_080', pretrained, **kwargs) + + +@register_model +def regnety_080_tv(pretrained=False, **kwargs) -> RegNet: + """RegNetY-8.0GF w/ torchvision group rounding""" + return _create_regnet('regnety_080_tv', pretrained, **kwargs) + + +@register_model +def regnety_120(pretrained=False, **kwargs) -> RegNet: + """RegNetY-12GF""" + return _create_regnet('regnety_120', pretrained, **kwargs) + + +@register_model +def regnety_160(pretrained=False, **kwargs) -> RegNet: + """RegNetY-16GF""" + return _create_regnet('regnety_160', pretrained, **kwargs) + + +@register_model +def regnety_320(pretrained=False, **kwargs) -> RegNet: + """RegNetY-32GF""" + return _create_regnet('regnety_320', pretrained, **kwargs) + + +@register_model +def regnety_640(pretrained=False, **kwargs) -> RegNet: + """RegNetY-64GF""" + return _create_regnet('regnety_640', pretrained, **kwargs) + + +@register_model +def regnety_1280(pretrained=False, **kwargs) -> RegNet: + """RegNetY-128GF""" + return _create_regnet('regnety_1280', pretrained, **kwargs) + + +@register_model +def regnety_2560(pretrained=False, **kwargs) -> RegNet: + """RegNetY-256GF""" + return _create_regnet('regnety_2560', pretrained, **kwargs) + + +@register_model +def regnety_040_sgn(pretrained=False, **kwargs) -> RegNet: + """RegNetY-4.0GF w/ GroupNorm """ + return _create_regnet('regnety_040_sgn', pretrained, **kwargs) + + +@register_model +def regnetv_040(pretrained=False, **kwargs) -> RegNet: + """RegNetV-4.0GF (pre-activation)""" + return _create_regnet('regnetv_040', pretrained, **kwargs) + + +@register_model +def regnetv_064(pretrained=False, **kwargs) -> RegNet: + """RegNetV-6.4GF (pre-activation)""" + return _create_regnet('regnetv_064', pretrained, **kwargs) + + +@register_model +def regnetz_005(pretrained=False, **kwargs) -> RegNet: + """RegNetZ-500MF + NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + but it's not clear it is equivalent to paper model as not detailed in the paper. 
+ """ + return _create_regnet('regnetz_005', pretrained, zero_init_last=False, **kwargs) + + +@register_model +def regnetz_040(pretrained=False, **kwargs) -> RegNet: + """RegNetZ-4.0GF + NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + but it's not clear it is equivalent to paper model as not detailed in the paper. + """ + return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs) + + +@register_model +def regnetz_040_h(pretrained=False, **kwargs) -> RegNet: + """RegNetZ-4.0GF + NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + but it's not clear it is equivalent to paper model as not detailed in the paper. + """ + return _create_regnet('regnetz_040_h', pretrained, zero_init_last=False, **kwargs) + + +register_model_deprecations(__name__, { + 'regnetz_040h': 'regnetz_040_h', +}) \ No newline at end of file diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnest.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1438017ea11dbd98968e07f64120bcb66a6aac --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnest.py @@ -0,0 +1,251 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SplitAttn +from ._builder import build_model_with_cfg +from ._registry import register_model, generate_default_cfgs +from .resnet import ResNet + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + radix=1, + cardinality=1, + base_width=64, + avd=False, + avd_first=False, + is_first=False, + reduce_first=1, + dilation=1, + first_dilation=None, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + attn_layer=None, + aa_layer=None, + drop_block=None, + drop_path=None, + ): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttn( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_layer=drop_block) + self.bn2 = nn.Identity() + self.drop_block = nn.Identity() + self.act2 = nn.Identity() + else: + self.conv2 = nn.Conv2d( + group_width, group_width, 
kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last(self): + if getattr(self.bn3, 'weight', None) is not None: + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.drop_block(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out += shortcut + out = self.act3(out) + return out + + +def _create_resnest(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, + variant, + pretrained, + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'resnest14d.gluon_in1k': _cfg(hf_hub_id='timm/'), + 'resnest26d.gluon_in1k': _cfg(hf_hub_id='timm/'), + 'resnest50d.in1k': _cfg(hf_hub_id='timm/'), + 'resnest101e.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'resnest200e.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'), + 'resnest269e.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'), + 'resnest50d_4s2x40d.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic'), + 'resnest50d_1s4x24d.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic') +}) + + +@register_model +def resnest14d(pretrained=False, **kwargs) -> ResNet: + """ ResNeSt-14d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[1, 1, 1, 1], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest26d(pretrained=False, **kwargs) -> ResNet: + """ ResNeSt-26d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[2, 2, 2, 2], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest50d(pretrained=False, **kwargs) -> ResNet: + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. 
+ """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest101e(pretrained=False, **kwargs) -> ResNet: + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 23, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest200e(pretrained=False, **kwargs) -> ResNet: + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 24, 36, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest269e(pretrained=False, **kwargs) -> ResNet: + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 30, 48, 8], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest50d_4s2x40d(pretrained=False, **kwargs) -> ResNet: + """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, + block_args=dict(radix=4, avd=True, avd_first=True)) + return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + +@register_model +def resnest50d_1s4x24d(pretrained=False, **kwargs) -> ResNet: + """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, + block_args=dict(radix=1, avd=True, avd_first=True)) + return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..69e2894659550983f3d4aa712650bfccbfb3bca5 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnet.py @@ -0,0 +1,2025 @@ +"""PyTorch ResNet + +This started as a copy of 
https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +additional dropout and dynamic global avg/max pool. + +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman + +Copyright 2019, Ross Wightman +""" +import math +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ + get_attn, get_act_layer, get_norm_layer, create_classifier +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs, register_model_deprecations + +__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this + + +def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module: + if not aa_layer or not enable: + return nn.Identity() + if issubclass(aa_layer, nn.AvgPool2d): + return aa_layer(stride) + else: + return aa_layer(channels=channels, stride=stride) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 64, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None, + ): + """ + Args: + inplanes: Input channel dimensionality. + planes: Used to determine output channel dimensionalities. + stride: Stride used in convolution layers. + downsample: Optional downsample layer for residual path. + cardinality: Number of convolution groups. + base_width: Base width used to determine output channel dimensionality. + reduce_first: Reduction factor for first convolution output width of residual blocks. + dilation: Dilation rate for convolution layers. + first_dilation: Dilation rate for first convolution layer. + act_layer: Activation layer. + norm_layer: Normalization layer. + attn_layer: Attention layer. + aa_layer: Anti-aliasing layer. + drop_block: Class for DropBlock layer. + drop_path: Optional DropPath layer. 
+ """ + super(BasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock does not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, + dilation=first_dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() + self.act1 = act_layer(inplace=True) + self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa) + + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_path = drop_path + + def zero_init_last(self): + if getattr(self.bn2, 'weight', None) is not None: + nn.init.zeros_(self.bn2.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.drop_block(x) + x = self.act1(x) + x = self.aa(x) + + x = self.conv2(x) + x = self.bn2(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + + return x + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 64, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None, + ): + """ + Args: + inplanes: Input channel dimensionality. + planes: Used to determine output channel dimensionalities. + stride: Stride used in convolution layers. + downsample: Optional downsample layer for residual path. + cardinality: Number of convolution groups. + base_width: Base width used to determine output channel dimensionality. + reduce_first: Reduction factor for first convolution output width of residual blocks. + dilation: Dilation rate for convolution layers. + first_dilation: Dilation rate for first convolution layer. + act_layer: Activation layer. + norm_layer: Normalization layer. + attn_layer: Attention layer. + aa_layer: Anti-aliasing layer. + drop_block: Class for DropBlock layer. + drop_path: Optional DropPath layer. 
+ """ + super(Bottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1 if use_aa else stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() + self.act2 = act_layer(inplace=True) + self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa) + + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_path = drop_path + + def zero_init_last(self): + if getattr(self.bn3, 'weight', None) is not None: + nn.init.zeros_(self.bn3.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.drop_block(x) + x = self.act2(x) + x = self.aa(x) + + x = self.conv3(x) + x = self.bn3(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act3(x) + + return x + + +def downsample_conv( + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + norm_layer: Optional[Type[nn.Module]] = None, +) -> nn.Module: + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + + +def downsample_avg( + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + norm_layer: Optional[Type[nn.Module]] = None, +) -> nn.Module: + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) + + +def drop_blocks(drop_prob: float = 0.): + return [ + None, None, + partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, + partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None] + + +def make_blocks( + block_fn: Union[BasicBlock, Bottleneck], + channels: List[int], + block_repeats: List[int], + inplanes: int, + 
reduce_first: int = 1, + output_stride: int = 32, + down_kernel_size: int = 1, + avg_down: bool = False, + drop_block_rate: float = 0., + drop_path_rate: float = 0., + **kwargs, +) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: + stages = [] + feature_info = [] + net_num_blocks = sum(block_repeats) + net_block_idx = 0 + net_stride = 4 + dilation = prev_dilation = 1 + for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): + stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it + stride = 1 if stage_idx == 0 else 2 + if net_stride >= output_stride: + dilation *= stride + stride = 1 + else: + net_stride *= stride + + downsample = None + if stride != 1 or inplanes != planes * block_fn.expansion: + down_kwargs = dict( + in_channels=inplanes, + out_channels=planes * block_fn.expansion, + kernel_size=down_kernel_size, + stride=stride, + dilation=dilation, + first_dilation=prev_dilation, + norm_layer=kwargs.get('norm_layer'), + ) + downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) + + block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs) + blocks = [] + for block_idx in range(num_blocks): + downsample = downsample if block_idx == 0 else None + stride = stride if block_idx == 0 else 1 + block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule + blocks.append(block_fn( + inplanes, + planes, + stride, + downsample, + first_dilation=prev_dilation, + drop_path=DropPath(block_dpr) if block_dpr > 0. else None, + **block_kwargs, + )) + prev_dilation = dilation + inplanes = planes * block_fn.expansion + net_block_idx += 1 + + stages.append((stage_name, nn.Sequential(*blocks))) + feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) + + return stages, feature_info + + +class ResNet(nn.Module): + """ResNet / ResNeXt / SE-ResNeXt / SE-Net + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * have > 1 stride in the 3x3 conv layer of bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the + 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default. 
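+ + A minimal usage sketch (editor's addition; assumes the torch import at the top of this module): + >>> model = ResNet(Bottleneck, [3, 4, 6, 3]) # ResNet-50 layout, 1000-class head by default + >>> model(torch.randn(1, 3, 224, 224)).shape + torch.Size([1, 1000])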
+ + ResNet variants (the same modifications can be used in SE/ResNeXt models as well): + * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64) + * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample + * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128) + * t - 3 layer deep 3x3 stem, stem_width = 32 (24, 48, 64), average pool in downsample + * tn - 3 layer deep 3x3 stem, stem_width = 32 (24, 32, 64), average pool in downsample + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c, d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block + """ + + def __init__( + self, + block: Union[BasicBlock, Bottleneck], + layers: List[int], + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + global_pool: str = 'avg', + cardinality: int = 1, + base_width: int = 64, + stem_width: int = 64, + stem_type: str = '', + replace_stem_pool: bool = False, + block_reduce_first: int = 1, + down_kernel_size: int = 1, + avg_down: bool = False, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[Type[nn.Module]] = None, + drop_rate: float = 0.0, + drop_path_rate: float = 0., + drop_block_rate: float = 0., + zero_init_last: bool = True, + block_args: Optional[Dict[str, Any]] = None, + ): + """ + Args: + block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck. + layers (List[int]): number of layers in each block + num_classes (int): number of classification classes (default 1000) + in_chans (int): number of input (color) channels. (default 3) + output_stride (int): output stride of the network, 32, 16, or 8. (default 32) + global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg') + cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1) + base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64) + stem_width (int): number of channels in stem convolutions (default 64) + stem_type (str): The type of stem (default ''): + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution + block_reduce_first (int): Reduction factor for first convolution output width of residual blocks, + 1 for all archs except SENets, which use 2 (default 1) + down_kernel_size (int): kernel size of residual block downsample path, + 1x1 for most, 3x3 for SENets (default: 1) + avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False) + act_layer (str, nn.Module): activation layer + norm_layer (str, nn.Module): normalization layer + aa_layer (nn.Module): anti-aliasing layer + drop_rate (float): Dropout probability before classifier, for training (default 0.)
+ drop_path_rate (float): Stochastic depth drop-path rate (default 0.) + drop_block_rate (float): Drop block rate (default 0.) + zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) + block_args (dict): Extra kwargs to pass through to block module + """ + super(ResNet, self).__init__() + block_args = block_args or dict() + assert output_stride in (8, 16, 32) + self.num_classes = num_classes + self.drop_rate = drop_rate + self.grad_checkpointing = False + + act_layer = get_act_layer(act_layer) + norm_layer = get_norm_layer(norm_layer) + + # Stem + deep_stem = 'deep' in stem_type + inplanes = stem_width * 2 if deep_stem else 64 + if deep_stem: + stem_chs = (stem_width, stem_width) + if 'tiered' in stem_type: + stem_chs = (3 * (stem_width // 4), stem_width) + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), + norm_layer(stem_chs[0]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), + norm_layer(stem_chs[1]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(inplanes) + self.act1 = act_layer(inplace=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] + + # Stem pooling. The name 'maxpool' remains for weight compatibility. + if replace_stem_pool: + self.maxpool = nn.Sequential(*filter(None, [ + nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), + create_aa(aa_layer, channels=inplanes, stride=2) if aa_layer is not None else None, + norm_layer(inplanes), + act_layer(inplace=True), + ])) + else: + if aa_layer is not None: + if issubclass(aa_layer, nn.AvgPool2d): + self.maxpool = aa_layer(2) + else: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # Feature Blocks + channels = [64, 128, 256, 512] + stage_modules, stage_feature_info = make_blocks( + block, + channels, + layers, + inplanes, + cardinality=cardinality, + base_width=base_width, + output_stride=output_stride, + reduce_first=block_reduce_first, + avg_down=avg_down, + down_kernel_size=down_kernel_size, + act_layer=act_layer, + norm_layer=norm_layer, + aa_layer=aa_layer, + drop_block_rate=drop_block_rate, + drop_path_rate=drop_path_rate, + **block_args, + ) + for stage in stage_modules: + self.add_module(*stage) # layer1, layer2, etc + self.feature_info.extend(stage_feature_info) + + # Head (Pooling and Classifier) + self.num_features = 512 * block.expansion + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + self.init_weights(zero_init_last=zero_init_last) + + @torch.jit.ignore + def init_weights(self, zero_init_last: bool = True): + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if zero_init_last: + for m in self.modules(): + if hasattr(m, 'zero_init_last'): + m.zero_init_last() + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False): + matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + 
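+ # Editor's note: when enabled, forward_features below routes the four residual stages + # through checkpoint_seq, recomputing activations during backward to save memory.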
self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self, name_only: bool = False): + return 'fc' if name_only else self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.maxpool(x) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True) + else: + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + return x if pre_logits else self.fc(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_resnet(variant, pretrained: bool = False, **kwargs) -> ResNet: + return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +def _tcfg(url='', **kwargs): + return _cfg(url=url, **dict({'interpolation': 'bicubic'}, **kwargs)) + + +def _ttcfg(url='', **kwargs): + return _cfg(url=url, **dict({ + 'interpolation': 'bicubic', 'test_input_size': (3, 288, 288), 'test_crop_pct': 0.95, + 'origin_url': 'https://github.com/huggingface/pytorch-image-models', + }, **kwargs)) + + +def _rcfg(url='', **kwargs): + return _cfg(url=url, **dict({ + 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_input_size': (3, 288, 288), 'test_crop_pct': 1.0, + 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476' + }, **kwargs)) + + +def _r3cfg(url='', **kwargs): + return _cfg(url=url, **dict({ + 'interpolation': 'bicubic', 'input_size': (3, 160, 160), 'pool_size': (5, 5), + 'crop_pct': 0.95, 'test_input_size': (3, 224, 224), 'test_crop_pct': 0.95, + 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476', + }, **kwargs)) + + +def _gcfg(url='', **kwargs): + return _cfg(url=url, **dict({ + 'interpolation': 'bicubic', + 'origin_url': 'https://cv.gluon.ai/model_zoo/classification.html', + }, **kwargs)) + + +default_cfgs = generate_default_cfgs({ + # ResNet and Wide ResNet trained w/ timm (RSB paper and others) + 'resnet10t.c3_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_crop_pct=0.95, test_input_size=(3, 224, 224), + first_conv='conv1.0'), + 'resnet14t.c3_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_crop_pct=0.95, test_input_size=(3, 224, 224), + first_conv='conv1.0'), + 'resnet18.a1_in1k': _rcfg( + hf_hub_id='timm/', + 
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a1_0-d63eafa0.pth'), + 'resnet18.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a2_0-b61bd467.pth'), + 'resnet18.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a3_0-40c531c8.pth'), + 'resnet18d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', + first_conv='conv1.0'), + 'resnet34.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a1_0-46f8f793.pth'), + 'resnet34.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a2_0-82d47d71.pth'), + 'resnet34.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a3_0-a20cabb6.pth', + crop_pct=0.95), + 'resnet34.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), + 'resnet34d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth', + first_conv='conv1.0'), + 'resnet26.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth'), + 'resnet26d.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', + first_conv='conv1.0'), + 'resnet26t.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0), + 'resnet50.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth'), + 'resnet50.a1h_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth', + input_size=(3, 176, 176), pool_size=(6, 6), crop_pct=0.9, test_input_size=(3, 224, 224), test_crop_pct=1.0), + 'resnet50.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a2_0-a2746f79.pth'), + 'resnet50.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a3_0-59cae1ef.pth'), + 'resnet50.b1k_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_b1k-532a802a.pth'), + 'resnet50.b2k_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_b2k-1ba180c1.pth'), + 'resnet50.c1_in1k': _rcfg( + hf_hub_id='timm/', + 
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_c1-5ba5e060.pth'), + 'resnet50.c2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_c2-d01e05b2.pth'), + 'resnet50.d_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_d-f39db8af.pth'), + 'resnet50.ram_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth'), + 'resnet50.am_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_am-6c502b37.pth'), + 'resnet50.ra_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_ra-85ebb6e5.pth'), + 'resnet50.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth'), + 'resnet50d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', + first_conv='conv1.0'), + 'resnet50d.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth', + first_conv='conv1.0'), + 'resnet50d.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a2_0-a3adc64d.pth', + first_conv='conv1.0'), + 'resnet50d.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a3_0-403fdfad.pth', + first_conv='conv1.0'), + 'resnet50t.untrained': _ttcfg(first_conv='conv1.0'), + 'resnet101.a1h_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth'), + 'resnet101.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1_0-cdcb52a9.pth'), + 'resnet101.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a2_0-6edb36c7.pth'), + 'resnet101.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a3_0-1db14157.pth'), + 'resnet101d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, + test_crop_pct=1.0, test_input_size=(3, 320, 320)), + 'resnet152.a1h_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1h-dc400468.pth'), + 'resnet152.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1_0-2eee8a7a.pth'), + 'resnet152.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a2_0-b4c6978f.pth'), + 'resnet152.a3_in1k': 
_r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a3_0-134d4688.pth'), + 'resnet152d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, + test_crop_pct=1.0, test_input_size=(3, 320, 320)), + 'resnet200.untrained': _ttcfg(), + 'resnet200d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, + test_crop_pct=1.0, test_input_size=(3, 320, 320)), + 'wide_resnet50_2.racm_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth'), + + # torchvision resnet weights + 'resnet18.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet18-5c106cde.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet34.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet34-333f7ec4.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet50.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet50-19c8e357.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet50.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet50-11ad3fa6.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet101.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet101.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet101-cd907fc2.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet152.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet152-b121ed2d.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnet152.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnet152-f82ba261.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'wide_resnet50_2.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'wide_resnet50_2.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'wide_resnet101_2.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', + license='bsd-3-clause', 
origin_url='https://github.com/pytorch/vision'), + 'wide_resnet101_2.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + + # ResNets w/ alternative norm layers + 'resnet50_gn.a1h_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth', + crop_pct=0.94), + + # ResNeXt trained in timm (RSB paper and others) + 'resnext50_32x4d.a1h_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth'), + 'resnext50_32x4d.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1_0-b5a91a1d.pth'), + 'resnext50_32x4d.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a2_0-efc76add.pth'), + 'resnext50_32x4d.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a3_0-3e450271.pth'), + 'resnext50_32x4d.ra_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth'), + 'resnext50d_32x4d.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', + first_conv='conv1.0'), + 'resnext101_32x4d.untrained': _ttcfg(), + 'resnext101_64x4d.c1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnext101_64x4d_c-0d0e0cc0.pth'), + + # torchvision ResNeXt weights + 'resnext50_32x4d.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnext101_32x8d.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnext101_64x4d.tv_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth', + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnext50_32x4d.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + 'resnext101_32x8d.tv2_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth', + input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965, + license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'), + + # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags + # from https://github.com/facebookresearch/WSL-Images + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. 
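+ # Editor's note: keys follow timm's '<arch>.<pretrained_tag>' naming; as a sketch (assuming a + # recent timm install), timm.create_model('resnext101_32x8d.fb_wsl_ig1b_ft_in1k', pretrained=True) + # selects this exact weight set.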
+ 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'), + 'resnext101_32x16d.fb_wsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'), + 'resnext101_32x32d.fb_wsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'), + 'resnext101_32x48d.fb_wsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'), + + # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. + 'resnet18.fb_ssl_yfcc100m_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnet50.fb_ssl_yfcc100m_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext50_32x4d.fb_ssl_yfcc100m_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + + # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
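+ # Editor's note: in the tags above and below, '.fb_ssl_*' encodes semi-supervised pretraining + # (YFCC100M) and '.fb_swsl_*' semi-weakly supervised pretraining (IG-1B hashtags), each + # fine-tuned on ImageNet-1k ('ft_in1k'), as the tag names themselves encode.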
+ 'resnet18.fb_swsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnet50.fb_swsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext50_32x4d.fb_swsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext101_32x4d.fb_swsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext101_32x8d.fb_swsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + 'resnext101_32x16d.fb_swsl_ig1b_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth', + license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'), + + # Efficient Channel Attention ResNets + 'ecaresnet26t.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + test_crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnetlight.miil_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnetlight-75a9c627.pth', + test_crop_pct=0.95, test_input_size=(3, 288, 288)), + 'ecaresnet50d.miil_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet50d-93c81e3b.pth', + first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)), + 'ecaresnet50d_pruned.miil_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet50d_p-e4fa23c2.pth', + first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)), + 'ecaresnet50t.ra2_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + test_crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnet50t.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a1_0-99bd76a8.pth', + first_conv='conv1.0'), + 'ecaresnet50t.a2_in1k': _rcfg( + hf_hub_id='timm/', + 
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a2_0-b1c7b745.pth', + first_conv='conv1.0'), + 'ecaresnet50t.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a3_0-8cc311f1.pth', + first_conv='conv1.0'), + 'ecaresnet101d.miil_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet101d-153dad65.pth', + first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)), + 'ecaresnet101d_pruned.miil_in1k': _tcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet101d_p-9e74cb91.pth', + first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)), + 'ecaresnet200d.untrained': _ttcfg( + first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.95, pool_size=(8, 8)), + 'ecaresnet269d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', + first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.95, + test_crop_pct=1.0, test_input_size=(3, 352, 352)), + + # Efficient Channel Attention ResNeXts + 'ecaresnext26t_32x4d.untrained': _tcfg(first_conv='conv1.0'), + 'ecaresnext50t_32x4d.untrained': _tcfg(first_conv='conv1.0'), + + # Squeeze-Excitation ResNets, to eventually replace the models in senet.py + 'seresnet18.untrained': _ttcfg(), + 'seresnet34.untrained': _ttcfg(), + 'seresnet50.a1_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a1_0-ffa00869.pth', + crop_pct=0.95), + 'seresnet50.a2_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a2_0-850de0d9.pth', + crop_pct=0.95), + 'seresnet50.a3_in1k': _r3cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a3_0-317ecd56.pth', + crop_pct=0.95), + 'seresnet50.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth'), + 'seresnet50t.untrained': _ttcfg( + first_conv='conv1.0'), + 'seresnet101.untrained': _ttcfg(), + 'seresnet152.untrained': _ttcfg(), + 'seresnet152d.ra2_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth', + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, + test_crop_pct=1.0, test_input_size=(3, 320, 320) + ), + 'seresnet200d.untrained': _ttcfg( + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), + 'seresnet269d.untrained': _ttcfg( + first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), + + # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py + 'seresnext26d_32x4d.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', + first_conv='conv1.0'), + 'seresnext26t_32x4d.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', + first_conv='conv1.0'), + 
'seresnext50_32x4d.racm_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth'), + 'seresnext101_32x4d.untrained': _ttcfg(), + 'seresnext101_32x8d.ah_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth'), + 'seresnext101d_32x8d.ah_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth', + first_conv='conv1.0'), + + # ResNets with anti-aliasing / blur pool + 'resnetaa50d.sw_in12k_ft_in1k': _ttcfg( + hf_hub_id='timm/', + first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0), + 'resnetaa101d.sw_in12k_ft_in1k': _ttcfg( + hf_hub_id='timm/', + first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0), + 'seresnextaa101d_32x8d.sw_in12k_ft_in1k_288': _ttcfg( + hf_hub_id='timm/', + crop_pct=0.95, input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), test_crop_pct=1.0, + first_conv='conv1.0'), + 'seresnextaa101d_32x8d.sw_in12k_ft_in1k': _ttcfg( + hf_hub_id='timm/', + first_conv='conv1.0', test_crop_pct=1.0), + 'seresnextaa201d_32x8d.sw_in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', first_conv='conv1.0', pool_size=(12, 12), input_size=(3, 384, 384), crop_pct=1.0), + 'seresnextaa201d_32x8d.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, interpolation='bicubic', first_conv='conv1.0', + crop_pct=0.95, input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), test_crop_pct=1.0), + + 'resnetaa50d.sw_in12k': _ttcfg( + hf_hub_id='timm/', + num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0), + 'resnetaa50d.d_in12k': _ttcfg( + hf_hub_id='timm/', + num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0), + 'resnetaa101d.sw_in12k': _ttcfg( + hf_hub_id='timm/', + num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0), + 'seresnextaa101d_32x8d.sw_in12k': _ttcfg( + hf_hub_id='timm/', + num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0), + + 'resnetblur18.untrained': _ttcfg(), + 'resnetblur50.bt_in1k': _ttcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth'), + 'resnetblur50d.untrained': _ttcfg(first_conv='conv1.0'), + 'resnetblur101d.untrained': _ttcfg(first_conv='conv1.0'), + 'resnetaa34d.untrained': _ttcfg(first_conv='conv1.0'), + 'resnetaa50.a1h_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth'), + + 'seresnetaa50d.untrained': _ttcfg(first_conv='conv1.0'), + 'seresnextaa101d_32x8d.ah_in1k': _rcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth', + first_conv='conv1.0'), + + # ResNet-RS models + 'resnetrs50.tf_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth', + input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs101.tf_in1k': _cfg( + hf_hub_id='timm/', + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth', + input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs152.tf_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs200.tf_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetrs200_c-6b698b88.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs270.tf_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs350.tf_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs420.tf_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416), + interpolation='bicubic', first_conv='conv1.0'), + + # gluon resnet weights + 'resnet18.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'), + 'resnet34.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'), + 'resnet50.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'), + 'resnet101.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'), + 'resnet152.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'), + 'resnet50c.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth', + first_conv='conv1.0'), + 'resnet101c.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth', + first_conv='conv1.0'), + 'resnet152c.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth', + first_conv='conv1.0'), + 'resnet50d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + 
url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth', + first_conv='conv1.0'), + 'resnet101d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth', + first_conv='conv1.0'), + 'resnet152d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth', + first_conv='conv1.0'), + 'resnet50s.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth', + first_conv='conv1.0'), + 'resnet101s.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth', + first_conv='conv1.0'), + 'resnet152s.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth', + first_conv='conv1.0'), + 'resnext50_32x4d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), + 'resnext101_32x4d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), + 'resnext101_64x4d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), + 'seresnext50_32x4d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), + 'seresnext101_32x4d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), + 'seresnext101_64x4d.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), + 'senet154.gluon_in1k': _gcfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth', + first_conv='conv1.0'), +}) + + +@register_model +def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-10-T model. + """ + model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet14t(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-14-T model. + """ + model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet18(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) + return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet18d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-18-D model. 
+ """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet34(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3]) + return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet34d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-34-D model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet26(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-26 model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2]) + return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet26t(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-26-T model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet26d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-26-D model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet50(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3]) + return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet50c(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-C model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep') + return _create_resnet('resnet50c', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet50d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-D model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet50s(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-S model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep') + return _create_resnet('resnet50s', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet50t(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-T model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet101(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3]) + return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet101c(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-101-C model. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep') + return _create_resnet('resnet101c', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet101d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-101-D model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet101s(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-101-S model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep') + return _create_resnet('resnet101s', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet152(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3]) + return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet152c(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-152-C model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep') + return _create_resnet('resnet152c', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet152d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-152-D model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet152s(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-152-S model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep') + return _create_resnet('resnet152s', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet200(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-200 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3]) + return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet200d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-200-D model. + """ + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128) + return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a Wide ResNet-101-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128) + return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50 model w/ GroupNorm + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], norm_layer='groupnorm') + return _create_resnet('resnet50_gn', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt50-32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) + return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt-101 32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4) + return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt-101 32x8d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) + return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt-101 32x16d model + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16) + return _create_resnet('resnext101_32x16d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt-101 32x32d model + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32) + return _create_resnet('resnext101_32x32d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNeXt101-64x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4) + return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs an ECA-ResNeXt-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem and ECA attn. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-D model with eca. 
+ """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs an ECA-ResNet-50-T model. + Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-50-D light model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-101-D model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-101-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-200-D model with ECA. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-269-D model with ECA. + """ + model_args = dict( + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs an ECA-ResNeXt-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. 
+
+
+@register_model
+def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs an ECA-ResNeXt-26-T model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. This model replaces the SE module with the ECA module.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs an ECA-ResNeXt-50-T model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. This model replaces the SE module with the ECA module.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet18(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet34(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet50(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered',
+        avg_down=True, block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet101(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet152(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet:
+    model_args = dict(
+        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
+        avg_down=True, block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-200-D model with SE attn.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep',
+        avg_down=True, block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))
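Note that the seresnet* presets above differ from their plain resnet* counterparts only in block_args; the 'se' string is resolved to a squeeze-and-excitation layer class by timm's get_attn (the same helper the resnetrs* models further down call explicitly). A sketch of the preset delta, with stand-in values:

plain_50 = dict(block='Bottleneck', layers=[3, 4, 6, 3])   # stand-in for the real preset
se_50 = dict(plain_50, block_args=dict(attn_layer='se'))
assert se_50['layers'] == plain_50['layers']               # architecture otherwise unchanged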
+ """ + model_args = dict( + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', + avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a SE-ResNeXt-26-D model.` + This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for + combination of deep stem and avg_pool in downsample. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a SE-ResNet-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet: + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101_64x4d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def senet154(pretrained: bool = False, **kwargs) -> ResNet: + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se')) + return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-18 model with blur anti-aliasing + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d) + return _create_resnet('resnetblur18', pretrained, 
+
+
+@register_model
+def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-50 model with blur anti-aliasing.
+    """
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d)
+    return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-50-D model with blur anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
+        stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-101-D model with blur anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
+        stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-34-D model w/ avgpool anti-aliasing.
+    """
+    model_args = dict(
+        block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-50 model with avgpool anti-aliasing.
+    """
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d)
+    return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-50-D model with avgpool anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
+        stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-101-D model with avgpool anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
+        stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a SE-ResNet-50-D model with avgpool anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
+        stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a SE-ResNeXt-101-D 32x8d model with avgpool anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
+        stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
+        block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a SE-ResNeXt-201-D 32x8d model with avgpool anti-aliasing.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 24, 36, 4], cardinality=32, base_width=8,
+        stem_width=64, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
+        block_args=dict(attn_layer='se'))
+    return _create_resnet('seresnextaa201d_32x8d', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-50 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-101 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-152 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-200 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-270 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))
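The attn_layer line in the resnetrs* entrypoints pre-binds the SE reduction ratio before any block is built: get_attn('se') resolves the string to timm's squeeze-excite class, and functools.partial fixes rd_ratio=0.25 for every block. The same pattern with a stand-in class (names here are illustrative):

from functools import partial

class FakeSqueezeExcite:
    def __init__(self, channels, rd_ratio=1 / 16):
        self.rd_channels = max(1, int(channels * rd_ratio))

attn_layer = partial(FakeSqueezeExcite, rd_ratio=0.25)
print(attn_layer(512).rd_channels)  # 128, vs. 32 at the default 1/16 ratio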
+
+
+@register_model
+def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-350 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a ResNet-RS-420 model.
+    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+    """
+    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+    model_args = dict(
+        block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
+        avg_down=True, block_args=dict(attn_layer=attn_layer))
+    return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
+
+
+register_model_deprecations(__name__, {
+    'tv_resnet34': 'resnet34.tv_in1k',
+    'tv_resnet50': 'resnet50.tv_in1k',
+    'tv_resnet101': 'resnet101.tv_in1k',
+    'tv_resnet152': 'resnet152.tv_in1k',
+    'tv_resnext50_32x4d': 'resnext50_32x4d.tv_in1k',
+    'ig_resnext101_32x8d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
+    'ig_resnext101_32x16d': 'resnext101_32x16d.fb_wsl_ig1b_ft_in1k',
+    'ig_resnext101_32x32d': 'resnext101_32x32d.fb_wsl_ig1b_ft_in1k',
+    'ig_resnext101_32x48d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
+    'ssl_resnet18': 'resnet18.fb_ssl_yfcc100m_ft_in1k',
+    'ssl_resnet50': 'resnet50.fb_ssl_yfcc100m_ft_in1k',
+    'ssl_resnext50_32x4d': 'resnext50_32x4d.fb_ssl_yfcc100m_ft_in1k',
+    'ssl_resnext101_32x4d': 'resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k',
+    'ssl_resnext101_32x8d': 'resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k',
+    'ssl_resnext101_32x16d': 'resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k',
+    'swsl_resnet18': 'resnet18.fb_swsl_ig1b_ft_in1k',
+    'swsl_resnet50': 'resnet50.fb_swsl_ig1b_ft_in1k',
+    'swsl_resnext50_32x4d': 'resnext50_32x4d.fb_swsl_ig1b_ft_in1k',
+    'swsl_resnext101_32x4d': 'resnext101_32x4d.fb_swsl_ig1b_ft_in1k',
+    'swsl_resnext101_32x8d': 'resnext101_32x8d.fb_swsl_ig1b_ft_in1k',
+    'swsl_resnext101_32x16d': 'resnext101_32x16d.fb_swsl_ig1b_ft_in1k',
+    'gluon_resnet18_v1b': 'resnet18.gluon_in1k',
+    'gluon_resnet34_v1b': 'resnet34.gluon_in1k',
+    'gluon_resnet50_v1b': 'resnet50.gluon_in1k',
+    'gluon_resnet101_v1b': 'resnet101.gluon_in1k',
+    'gluon_resnet152_v1b': 'resnet152.gluon_in1k',
+    'gluon_resnet50_v1c': 'resnet50c.gluon_in1k',
+    'gluon_resnet101_v1c': 'resnet101c.gluon_in1k',
+    'gluon_resnet152_v1c': 'resnet152c.gluon_in1k',
+    'gluon_resnet50_v1d': 'resnet50d.gluon_in1k',
+    'gluon_resnet101_v1d': 'resnet101d.gluon_in1k',
+    'gluon_resnet152_v1d': 'resnet152d.gluon_in1k',
+    'gluon_resnet50_v1s': 'resnet50s.gluon_in1k',
+    'gluon_resnet101_v1s': 'resnet101s.gluon_in1k',
+    'gluon_resnet152_v1s': 'resnet152s.gluon_in1k',
+    'gluon_resnext50_32x4d': 'resnext50_32x4d.gluon_in1k',
+    'gluon_resnext101_32x4d': 'resnext101_32x4d.gluon_in1k',
+    'gluon_resnext101_64x4d': 'resnext101_64x4d.gluon_in1k',
+    'gluon_seresnext50_32x4d': 'seresnext50_32x4d.gluon_in1k',
+    'gluon_seresnext101_32x4d': 'seresnext101_32x4d.gluon_in1k',
+    'gluon_seresnext101_64x4d': 'seresnext101_64x4d.gluon_in1k',
+    'gluon_senet154': 'senet154.gluon_in1k',
+    'seresnext26tn_32x4d': 'seresnext26t_32x4d',
+})
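What the table above buys users, as a hedged usage sketch (assuming the timm >= 0.9 registry behavior that register_model_deprecations implements): requesting a legacy name emits a deprecation warning and resolves to the new 'architecture.pretrained_tag' form.

import timm

# both build the same architecture; the first also emits a deprecation warning
m_old = timm.create_model('tv_resnet50', pretrained=False)
m_new = timm.create_model('resnet50.tv_in1k', pretrained=False)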
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnetv2.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..63fb20332660fd25d0b5b05ee19e9e7d2a08abd8
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/resnetv2.py
@@ -0,0 +1,773 @@
+"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
+
+A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code
+at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
+been included here as pretrained models from their original .NPZ checkpoints.
+
+Additionally, supports a non-pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
+extra padding support to allow porting of official Hybrid ResNet pretrained weights from
+https://github.com/google-research/vision_transformer
+
+Thanks to the Google team for the above two repositories and associated papers:
+* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
+* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
+* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+
+Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
+"""
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict  # pylint: disable=g-importing-member
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
+    DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
+
+__all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this
+
+
+class PreActBottleneck(nn.Module):
+    """Pre-activation (v2) bottleneck block.
+
+    Follows the implementation of "Identity Mappings in Deep Residual Networks":
+    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
+
+    Except it puts the stride on the 3x3 conv when available.
+ """ + + def __init__( + self, + in_chs, + out_chs=None, + bottle_ratio=0.25, + stride=1, + dilation=1, + first_dilation=None, + groups=1, + act_layer=None, + conv_layer=None, + norm_layer=None, + proj_layer=None, + drop_path_rate=0., + ): + super().__init__() + first_dilation = first_dilation or dilation + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_divisible(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.norm1 = norm_layer(in_chs) + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm2 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm3 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def zero_init_last(self): + nn.init.zeros_(self.conv3.weight) + + def forward(self, x): + x_preact = self.norm1(x) + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x_preact) + + # residual branch + x = self.conv1(x_preact) + x = self.conv2(self.norm2(x)) + x = self.conv3(self.norm3(x)) + x = self.drop_path(x) + return x + shortcut + + +class Bottleneck(nn.Module): + """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT. + """ + def __init__( + self, + in_chs, + out_chs=None, + bottle_ratio=0.25, + stride=1, + dilation=1, + first_dilation=None, + groups=1, + act_layer=None, + conv_layer=None, + norm_layer=None, + proj_layer=None, + drop_path_rate=0., + ): + super().__init__() + first_dilation = first_dilation or dilation + act_layer = act_layer or nn.ReLU + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_divisible(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, preact=False, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm1 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm2 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.norm3 = norm_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.act3 = act_layer(inplace=True) + + def zero_init_last(self): + if getattr(self.norm3, 'weight', None) is not None: + nn.init.zeros_(self.norm3.weight) + + def forward(self, x): + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + # residual + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.conv3(x) + x = self.norm3(x) + x = self.drop_path(x) + x = self.act3(x + shortcut) + return x + + +class DownsampleConv(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + ): + super(DownsampleConv, self).__init__() + self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) + 
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(x)) + + +class DownsampleAvg(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + ): + """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ResNetStage(nn.Module): + """ResNet Stage.""" + def __init__( + self, + in_chs, + out_chs, + stride, + dilation, + depth, + bottle_ratio=0.25, + groups=1, + avg_down=False, + block_dpr=None, + block_fn=PreActBottleneck, + act_layer=None, + conv_layer=None, + norm_layer=None, + **block_kwargs, + ): + super(ResNetStage, self).__init__() + first_dilation = 1 if dilation in (1, 2) else 2 + layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) + proj_layer = DownsampleAvg if avg_down else DownsampleConv + prev_chs = in_chs + self.blocks = nn.Sequential() + for block_idx in range(depth): + drop_path_rate = block_dpr[block_idx] if block_dpr else 0. + stride = stride if block_idx == 0 else 1 + self.blocks.add_module(str(block_idx), block_fn( + prev_chs, + out_chs, + stride=stride, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + first_dilation=first_dilation, + proj_layer=proj_layer, + drop_path_rate=drop_path_rate, + **layer_kwargs, + **block_kwargs, + )) + prev_chs = out_chs + first_dilation = dilation + proj_layer = None + + def forward(self, x): + x = self.blocks(x) + return x + + +def is_stem_deep(stem_type): + return any([s in stem_type for s in ('deep', 'tiered')]) + + +def create_resnetv2_stem( + in_chs, + out_chs=64, + stem_type='', + preact=True, + conv_layer=StdConv2d, + norm_layer=partial(GroupNormAct, num_groups=32), +): + stem = OrderedDict() + assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') + + # NOTE conv padding mode can be changed by overriding the conv_layer def + if is_stem_deep(stem_type): + # A 3 deep 3x3 conv stack as in ResNet V1D models + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets + stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2) + stem['norm1'] = norm_layer(stem_chs[0]) + stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1) + stem['norm2'] = norm_layer(stem_chs[1]) + stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1) + if not preact: + stem['norm3'] = norm_layer(out_chs) + else: + # The usual 7x7 stem conv + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + if not preact: + stem['norm'] = norm_layer(out_chs) + + if 'fixed' in stem_type: + # 'fixed' SAME padding approximation that is used in BiT models + stem['pad'] = nn.ConstantPad2d(1, 0.) 
+        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
+    elif 'same' in stem_type:
+        # full, input size based 'SAME' padding, used in ViT Hybrid model
+        stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
+    else:
+        # the usual PyTorch symmetric padding
+        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    return nn.Sequential(stem)
+
+
+class ResNetV2(nn.Module):
+    """Implementation of Pre-activation (v2) ResNet model.
+    """
+
+    def __init__(
+            self,
+            layers,
+            channels=(256, 512, 1024, 2048),
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            width_factor=1,
+            stem_chs=64,
+            stem_type='',
+            avg_down=False,
+            preact=True,
+            act_layer=nn.ReLU,
+            norm_layer=partial(GroupNormAct, num_groups=32),
+            conv_layer=StdConv2d,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            zero_init_last=False,
+    ):
+        """
+        Args:
+            layers (List[int]): number of layers in each block
+            channels (List[int]): number of channels in each block
+            num_classes (int): number of classification classes (default 1000)
+            in_chans (int): number of input (color) channels. (default 3)
+            global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
+            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
+            width_factor (int): channel (width) multiplication factor
+            stem_chs (int): stem width (default: 64)
+            stem_type (str): stem type (default: '' == 7x7)
+            avg_down (bool): average pooling in residual downsampling (default: False)
+            preact (bool): pre-activation (default: True)
+            act_layer (Union[str, nn.Module]): activation layer
+            norm_layer (Union[str, nn.Module]): normalization layer
+            conv_layer (nn.Module): convolution module
+            drop_rate: classifier dropout rate (default: 0.)
+            drop_path_rate: stochastic depth rate (default: 0.)
+ zero_init_last: zero-init last weight in residual path (default: False) + """ + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + wf = width_factor + norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) + act_layer = get_act_layer(act_layer) + + self.feature_info = [] + stem_chs = make_divisible(stem_chs * wf) + self.stem = create_resnetv2_stem( + in_chans, + stem_chs, + stem_type, + preact, + conv_layer=conv_layer, + norm_layer=norm_layer, + ) + stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) + + prev_chs = stem_chs + curr_stride = 4 + dilation = 1 + block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + block_fn = PreActBottleneck if preact else Bottleneck + self.stages = nn.Sequential() + for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)): + out_chs = make_divisible(c * wf) + stride = 1 if stage_idx == 0 else 2 + if curr_stride >= output_stride: + dilation *= stride + stride = 1 + stage = ResNetStage( + prev_chs, + out_chs, + stride=stride, + dilation=dilation, + depth=d, + avg_down=avg_down, + act_layer=act_layer, + conv_layer=conv_layer, + norm_layer=norm_layer, + block_dpr=bdpr, + block_fn=block_fn, + ) + prev_chs = out_chs + curr_stride *= stride + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')] + self.stages.add_module(str(stage_idx), stage) + + self.num_features = prev_chs + self.norm = norm_layer(self.num_features) if preact else nn.Identity() + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + use_conv=True, + ) + + self.init_weights(zero_init_last=zero_init_last) + self.grad_checkpointing = False + + @torch.jit.ignore + def init_weights(self, zero_init_last=True): + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix='resnet/'): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x, flatten=True) + else: + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module: nn.Module, name: str = '', zero_init_last=True): + if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not 
None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif zero_init_last and hasattr(module, 'zero_init_last'): + module.zero_init_last() + + +@torch.no_grad() +def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'): + import numpy as np + + def t2p(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + weights = np.load(checkpoint_path) + stem_conv_w = adapt_input_conv( + model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) + model.stem.conv.weight.copy_(stem_conv_w) + model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma'])) + model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta'])) + if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \ + model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: + model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel'])) + model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias'])) + for i, (sname, stage) in enumerate(model.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + cname = 'standardized_conv2d' + block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' + block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel'])) + block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel'])) + block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel'])) + block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma'])) + block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma'])) + block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma'])) + block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta'])) + block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta'])) + block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta'])) + if block.downsample is not None: + w = weights[f'{block_prefix}a/proj/{cname}/kernel'] + block.downsample.conv.weight.copy_(t2p(w)) + + +def _create_resnetv2(variant, pretrained=False, **kwargs): + feature_cfg = dict(flatten_sequential=True) + return build_model_with_cfg( + ResNetV2, variant, pretrained, + feature_cfg=feature_cfg, + **kwargs, + ) + + +def _create_resnetv2_bit(variant, pretrained=False, **kwargs): + return _create_resnetv2( + variant, + pretrained=pretrained, + stem_type='fixed', + conv_layer=partial(StdConv2d, eps=1e-8), + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + 'resnetv2_50x1_bit.goog_distilled_in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', custom_load=True), + 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', custom_load=True), + 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, 
interpolation='bicubic', custom_load=True), + + # pretrained on imagenet21k, finetuned on imagenet1k + 'resnetv2_50x1_bit.goog_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), + 'resnetv2_50x3_bit.goog_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), + 'resnetv2_101x1_bit.goog_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), + 'resnetv2_101x3_bit.goog_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), + 'resnetv2_152x2_bit.goog_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), + 'resnetv2_152x4_bit.goog_in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True), # only one at 480x480? + + # trained on imagenet-21k + 'resnetv2_50x1_bit.goog_in21k': _cfg( + hf_hub_id='timm/', + num_classes=21843, custom_load=True), + 'resnetv2_50x3_bit.goog_in21k': _cfg( + hf_hub_id='timm/', + num_classes=21843, custom_load=True), + 'resnetv2_101x1_bit.goog_in21k': _cfg( + hf_hub_id='timm/', + num_classes=21843, custom_load=True), + 'resnetv2_101x3_bit.goog_in21k': _cfg( + hf_hub_id='timm/', + num_classes=21843, custom_load=True), + 'resnetv2_152x2_bit.goog_in21k': _cfg( + hf_hub_id='timm/', + num_classes=21843, custom_load=True), + 'resnetv2_152x4_bit.goog_in21k': _cfg( + hf_hub_id='timm/', + num_classes=21843, custom_load=True), + + 'resnetv2_50.a1h_in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'resnetv2_50d.untrained': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50t.untrained': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_101.a1h_in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'resnetv2_101d.untrained': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_152.untrained': _cfg( + interpolation='bicubic'), + 'resnetv2_152d.untrained': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + + 'resnetv2_50d_gn.ah_in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', first_conv='stem.conv1', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'resnetv2_50d_evos.ah_in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic', first_conv='stem.conv1', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'resnetv2_50d_frn.untrained': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), +}) + + +@register_model +def resnetv2_50x1_bit(pretrained=False, **kwargs) -> ResNetV2: + return _create_resnetv2_bit( + 'resnetv2_50x1_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_50x3_bit(pretrained=False, **kwargs) -> ResNetV2: + return _create_resnetv2_bit( + 'resnetv2_50x3_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs) + + +@register_model +def resnetv2_101x1_bit(pretrained=False, **kwargs) -> ResNetV2: + return _create_resnetv2_bit( + 'resnetv2_101x1_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_101x3_bit(pretrained=False, **kwargs) -> ResNetV2: + return 
_create_resnetv2_bit(
+        'resnetv2_101x3_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
+
+
+@register_model
+def resnetv2_152x2_bit(pretrained=False, **kwargs) -> ResNetV2:
+    return _create_resnetv2_bit(
+        'resnetv2_152x2_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+
+
+@register_model
+def resnetv2_152x4_bit(pretrained=False, **kwargs) -> ResNetV2:
+    return _create_resnetv2_bit(
+        'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
+
+
+@register_model
+def resnetv2_50(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
+    return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_50d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_50t(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='tiered', avg_down=True)
+    return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_101(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
+    return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_101d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_152(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
+    return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_152d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+# Experimental configs (may change / be removed)
+
+@register_model
+def resnetv2_50d_gn(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_50d_evos(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_50d_frn(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))
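The split between the BiT-style and conventional entrypoints above comes down to two constructor arguments: the *_bit models keep this file's defaults (weight-standardized StdConv2d plus GroupNorm), while resnetv2_50/101/152 swap in ordinary convs and BatchNorm via conv_layer/norm_layer. A hedged usage sketch, assuming timm is installed:

import timm
import torch

m = timm.create_model('resnetv2_50x1_bit', pretrained=False)  # GN + weight-standardized convs
print(m(torch.randn(1, 3, 224, 224)).shape)                   # torch.Size([1, 1000])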
+
+
+register_model_deprecations(__name__, {
+    'resnetv2_50x1_bitm': 'resnetv2_50x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x3_bitm': 'resnetv2_50x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x1_bitm': 'resnetv2_101x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x3_bitm': 'resnetv2_101x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x2_bitm': 'resnetv2_152x2_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x4_bitm': 'resnetv2_152x4_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x1_bitm_in21k': 'resnetv2_50x1_bit.goog_in21k',
+    'resnetv2_50x3_bitm_in21k': 'resnetv2_50x3_bit.goog_in21k',
+    'resnetv2_101x1_bitm_in21k': 'resnetv2_101x1_bit.goog_in21k',
+    'resnetv2_101x3_bitm_in21k': 'resnetv2_101x3_bit.goog_in21k',
+    'resnetv2_152x2_bitm_in21k': 'resnetv2_152x2_bit.goog_in21k',
+    'resnetv2_152x4_bitm_in21k': 'resnetv2_152x4_bit.goog_in21k',
+    'resnetv2_50x1_bit_distilled': 'resnetv2_50x1_bit.goog_distilled_in1k',
+    'resnetv2_152x2_bit_teacher': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k',
+    'resnetv2_152x2_bit_teacher_384': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384',
+})
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/senet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/senet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b203c372c39491477603ed43ca23dda93b6fead
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/senet.py
@@ -0,0 +1,465 @@
+"""
+SEResNet implementation from Cadene's pretrained models
+https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
+Additional credit to https://github.com/creafz
+
+Original model: https://github.com/hujie-frank/SENet
+
+ResNet code gently borrowed from
+https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+
+FIXME I'm deprecating these models and moving them to ResNet as I don't want to maintain duplicate
+support for extras like dilation, switchable BN/activations, feature extraction, etc. that don't exist here.
+"""
+import math
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import create_classifier
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['SENet']
+
+
+def _weight_init(m):
+    if isinstance(m, nn.Conv2d):
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+    elif isinstance(m, nn.BatchNorm2d):
+        nn.init.constant_(m.weight, 1.)
+        nn.init.constant_(m.bias, 0.)
+
+
+class SEModule(nn.Module):
+
+    def __init__(self, channels, reduction):
+        super(SEModule, self).__init__()
+        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        module_input = x
+        x = x.mean((2, 3), keepdim=True)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.sigmoid(x)
+        return module_input * x
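A quick sanity check of SEModule's squeeze-excite recalibration, using the class defined just above (reduction=16 maps 64 channels down to 4 and back, then rescales the input per channel):

import torch

se = SEModule(channels=64, reduction=16)
out = se(torch.randn(2, 64, 32, 32))
print(out.shape)  # torch.Size([2, 64, 32, 32]) -- same shape, channel-wise rescaled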
+ """ + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out = self.se_module(out) + shortcut + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d( + planes * 2, planes * 4, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEResNetBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes, reduction=reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out = self.se_module(out) + shortcut + out = self.relu(out) + + return out + + +class SENet(nn.Module): + + def __init__( + self, block, layers, groups, reduction, drop_rate=0.2, + in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, + downsample_padding=0, num_classes=1000, global_pool='avg'): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. 
+ - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + self.num_classes = num_classes + self.drop_rate = drop_rate + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d( + in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`. + self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')] + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')] + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')] + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')] + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] + self.num_features = 512 * block.expansion + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + _weight_init(m) + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, + stride=stride, padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def 
get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.layer0(x) + x = self.pool0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return x if pre_logits else self.last_linear(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_senet(variant, pretrained=False, **kwargs): + return build_model_with_cfg(SENet, variant, pretrained, **kwargs) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'legacy_senet154.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth'), + 'legacy_seresnet18.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', + interpolation='bicubic'), + 'legacy_seresnet34.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), + 'legacy_seresnet50.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), + 'legacy_seresnet101.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), + 'legacy_seresnet152.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), + 'legacy_seresnext26_32x4d.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', + interpolation='bicubic'), + 'legacy_seresnext50_32x4d.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth'), + 'legacy_seresnext101_32x4d.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth'), +}) + + +@register_model +def legacy_seresnet18(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet18', pretrained, **model_args) + + +@register_model +def legacy_seresnet34(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet34', pretrained, **model_args) + + +@register_model +def legacy_seresnet50(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet50', pretrained, **model_args) + + +@register_model +def legacy_seresnet101(pretrained=False, **kwargs) -> SENet: + 
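# ResNet-101 depth layout [3, 4, 23, 3] with SE modules (reduction ratio 16) +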
model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet101', pretrained, **model_args) + + +@register_model +def legacy_seresnet152(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet152', pretrained, **model_args) + + +@register_model +def legacy_senet154(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16, + downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs) + return _create_senet('legacy_senet154', pretrained, **model_args) + + +@register_model +def legacy_seresnext26_32x4d(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext50_32x4d(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext101_32x4d(pretrained=False, **kwargs) -> SENet: + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3f9508b98eab31c9145ef879f8e39417cf3eb9 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py @@ -0,0 +1,851 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from + - https://github.com/microsoft/Cream/tree/main/AutoFormerV2 + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ + _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq, named_apply +from ._registry import generate_default_cfgs, register_model, register_model_deprecations +from .vision_transformer import get_init_weights_vit + +__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to 
this + +_logger = logging.getLogger(__name__) + +_int_or_tuple_2_t = Union[int, Tuple[int, int]] + + +def window_partition( + x: torch.Tensor, + window_size: Tuple[int, int], +) -> torch.Tensor: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def get_relative_position_index(win_h: int, win_w: int): + # get pair-wise relative position index for each token inside the window + coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w))) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += win_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += win_w - 1 + relative_coords[:, :, 0] *= 2 * win_w - 1 + return relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports shifted and non-shifted windows. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + ): + """ + Args: + dim: Number of input channels. + num_heads: Number of attention heads. + head_dim: Number of channels per head (dim // num_heads if not set) + window_size: The height and width of the window. + qkv_bias: If True, add a learnable bias to query, key, value. + attn_drop: Dropout ratio of attention weight. + proj_drop: Dropout ratio of output. 
+ """ + super().__init__() + self.dim = dim + self.window_size = to_2tuple(window_size) # Wh, Ww + win_h, win_w = self.window_size + self.window_area = win_h * win_w + self.num_heads = num_heads + head_dim = head_dim or dim // num_heads + attn_dim = head_dim * num_heads + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn(experimental=True) # NOTE not tested for prime-time yet + + # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH + self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) + + # get pair-wise relative position index for each token inside the window + self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) + + self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(attn_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if self.fused_attn: + attn_mask = self._get_rel_pos_bias() + if mask is not None: + num_win = mask.shape[0] + mask = mask.view(1, num_win, 1, N, N).expand(B_ // num_win, -1, self.num_heads, -1, -1) + attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn + self._get_rel_pos_bias() + if mask is not None: + num_win = mask.shape[0] + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B_, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + """ + + def __init__( + self, + dim: int, + input_resolution: _int_or_tuple_2_t, + num_heads: int = 4, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + shift_size: int = 0, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + window_size: Window size. + num_heads: Number of attention heads. + head_dim: Enforce the number of channels per head + shift_size: Shift size for SW-MSA. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + proj_drop: Dropout rate. + attn_drop: Attention dropout rate. 
+ drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + num_heads=num_heads, + head_dim=head_dim, + window_size=to_2tuple(self.window_size), + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if any(self.shift_size): + # calculate attention mask for SW-MSA + H, W = self.input_resolution + H = math.ceil(H / self.window_size[0]) * self.window_size[0] + W = math.ceil(W / self.window_size[1]) * self.window_size[1] + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): + for w in ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask, persistent=False) + + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) + + def _attn(self, x): + B, H, W, C = x.shape + + # cyclic shift + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + else: + shifted_x = x + + # pad for resolution not divisible by window size + pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + shifted_x = shifted_x[:, :H, :W, 
:].contiguous() + + # reverse cyclic shift + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) + else: + x = shifted_x + return x + + def forward(self, x): + B, H, W, C = x.shape + x = x + self.drop_path1(self._attn(self.norm1(x))) + x = x.reshape(B, -1, C) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x.reshape(B, H, W, C) + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer. + """ + + def __init__( + self, + dim: int, + out_dim: Optional[int] = None, + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + out_dim: Number of output channels (or 2 * dim if None) + norm_layer: Normalization layer. + """ + super().__init__() + self.dim = dim + self.out_dim = out_dim or 2 * dim + self.norm = norm_layer(4 * dim) + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) + + def forward(self, x): + B, H, W, C = x.shape + _assert(H % 2 == 0, f"x height ({H}) is not even.") + _assert(W % 2 == 0, f"x width ({W}) is not even.") + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) + x = self.norm(x) + x = self.reduction(x) + return x + + +class SwinTransformerStage(nn.Module): + """ A basic Swin Transformer layer for one stage. + """ + + def __init__( + self, + dim: int, + out_dim: int, + input_resolution: Tuple[int, int], + depth: int, + downsample: bool = True, + num_heads: int = 4, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0., + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + out_dim: Number of output channels. + input_resolution: Input resolution. + depth: Number of blocks. + downsample: Downsample layer at the end of the layer. + num_heads: Number of attention heads. + head_dim: Channels per head (dim // num_heads if not set) + window_size: Local window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + proj_drop: Projection dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. 
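+
+ Illustrative shape sketch (assumed values, not part of the original docstring);
+ with downsample=True the stage halves the grid and projects dim -> out_dim:
+
+ >>> stage = SwinTransformerStage(dim=96, out_dim=192, input_resolution=(56, 56), depth=2, downsample=True)
+ >>> stage(torch.randn(1, 56, 56, 96)).shape # NHWC in, NHWC out
+ torch.Size([1, 28, 28, 192])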
+ """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution + self.depth = depth + self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for w in window_size]) + + # patch merging layer + if downsample: + self.downsample = PatchMerging( + dim=dim, + out_dim=out_dim, + norm_layer=norm_layer, + ) + else: + assert dim == out_dim + self.downsample = nn.Identity() + + # build blocks + self.blocks = nn.Sequential(*[ + SwinTransformerBlock( + dim=out_dim, + input_resolution=self.output_resolution, + num_heads=num_heads, + head_dim=head_dim, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + def forward(self, x): + x = self.downsample(x) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer + + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + """ + + def __init__( + self, + img_size: _int_or_tuple_2_t = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + embed_layer: Callable = PatchEmbed, + norm_layer: Union[str, Callable] = nn.LayerNorm, + weight_init: str = '', + **kwargs, + ): + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer layer. + num_heads: Number of attention heads in different layers. + head_dim: Dimension of self-attention heads. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop_rate: Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + embed_layer: Patch embedding layer. + norm_layer (nn.Module): Normalization layer. 
+ """ + super().__init__() + assert global_pool in ('', 'avg') + self.num_classes = num_classes + self.global_pool = global_pool + self.output_fmt = 'NHWC' + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.feature_info = [] + + if not isinstance(embed_dim, (tuple, list)): + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + + # split image into non-overlapping patches + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim[0], + norm_layer=norm_layer, + output_fmt='NHWC', + ) + self.patch_grid = self.patch_embed.grid_size + + # build layers + head_dim = to_ntuple(self.num_layers)(head_dim) + if not isinstance(window_size, (list, tuple)): + window_size = to_ntuple(self.num_layers)(window_size) + elif len(window_size) == 2: + window_size = (window_size,) * self.num_layers + assert len(window_size) == self.num_layers + mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + layers = [] + in_dim = embed_dim[0] + scale = 1 + for i in range(self.num_layers): + out_dim = embed_dim[i] + layers += [SwinTransformerStage( + dim=in_dim, + out_dim=out_dim, + input_resolution=( + self.patch_grid[0] // scale, + self.patch_grid[1] // scale + ), + depth=depths[i], + downsample=i > 0, + num_heads=num_heads[i], + head_dim=head_dim[i], + window_size=window_size[i], + mlp_ratio=mlp_ratio[i], + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + )] + in_dim = out_dim + if i > 0: + scale *= 2 + self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + input_fmt=self.output_fmt, + ) + if weight_init != 'skip': + self.init_weights(weight_init) + + @torch.jit.ignore + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
+ named_apply(get_init_weights_vit(mode, head_bias=head_bias), self) + + @torch.jit.ignore + def no_weight_decay(self): + nwd = set() + for n, _ in self.named_parameters(): + if 'relative_position_bias_table' in n: + nwd.add(n) + return nwd + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^patch_embed', # stem and embed + blocks=r'^layers\.(\d+)' if coarse else [ + (r'^layers\.(\d+).downsample', (0,)), + (r'^layers\.(\d+)\.\w+\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.layers(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + old_weights = True + if 'head.fc.weight' in state_dict: + old_weights = False + import re + out_dict = {} + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + for k, v in state_dict.items(): + if any([n in k for n in ('relative_position_index', 'attn_mask')]): + continue # skip buffers that should not be persistent + + if 'patch_embed.proj.weight' in k: + _, _, H, W = model.patch_embed.proj.weight.shape + if v.shape[-2] != H or v.shape[-1] != W: + v = resample_patch_embed( + v, + (H, W), + interpolation='bicubic', + antialias=True, + verbose=True, + ) + + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + + if old_weights: + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + + out_dict[k] = v + return out_dict + + +def _create_swin_transformer(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'license': 'mit', **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ), + 
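# pretrained tag convention: 'ms_*' = original Microsoft weights,
+ # 'in22k_ft_in1k' = ImageNet-22k pretraining fine-tuned to ImageNet-1k +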
'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',), + 'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',), + 'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'swin_tiny_patch4_window7_224.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',), + 'swin_small_patch4_window7_224.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',), + 'swin_base_patch4_window7_224.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',), + 'swin_base_patch4_window12_384.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + # tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order) + 'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',), + + 'swin_tiny_patch4_window7_224.ms_in22k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth', + num_classes=21841), + 'swin_small_patch4_window7_224.ms_in22k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth', + num_classes=21841), + 'swin_base_patch4_window7_224.ms_in22k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + 'swin_base_patch4_window12_384.ms_in22k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841), + 'swin_large_patch4_window7_224.ms_in22k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + 'swin_large_patch4_window12_384.ms_in22k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841), + + 'swin_s3_tiny_224.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'), + 
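# swin_s3_* = S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) weights;
+ # these searched variants use per-stage window sizes (see the model defs below) +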
'swin_s3_small_224.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'), + 'swin_s3_base_224.ms_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'), +}) + + +@register_model +def swin_tiny_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-T @ 224x224, trained ImageNet-1k + """ + model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)) + return _create_swin_transformer( + 'swin_tiny_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_small_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-S @ 224x224 + """ + model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)) + return _create_swin_transformer( + 'swin_small_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_base_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-B @ 224x224 + """ + model_args = dict(patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) + return _create_swin_transformer( + 'swin_base_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_base_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-B @ 384x384 + """ + model_args = dict(patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) + return _create_swin_transformer( + 'swin_base_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_large_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-L @ 224x224 + """ + model_args = dict(patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)) + return _create_swin_transformer( + 'swin_large_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_large_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-L @ 384x384 + """ + model_args = dict(patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)) + return _create_swin_transformer( + 'swin_large_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_s3_tiny_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725 + """ + model_args = dict( + patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)) + return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_s3_small_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725 + """ + model_args = dict( + patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)) + return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_s3_base_224(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725 + """ + model_args = dict( + patch_size=4, 
window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2), num_heads=(3, 6, 12, 24)) + return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +register_model_deprecations(__name__, { + 'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k', + 'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k', + 'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k', + 'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k', +}) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer_v2.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b152b5447002cccb321f46b949a6622dfcb88126 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer_v2.py @@ -0,0 +1,853 @@ +""" Swin Transformer V2 +A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer V2 +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import math +from typing import Callable, Optional, Tuple, Union, Set, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\ + resample_patch_embed, ndgrid +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this + +_int_or_tuple_2_t = Union[int, Tuple[int, int]] + + +def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor: + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor: + """ + Args: + windows: (num_windows * B, window_size[0], window_size[1], C) + window_size (Tuple[int, int]): Window size + img_size (Tuple[int, int]): Image size + + Returns: + x: (B, H, W, C) + """ + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative 
position bias. + It supports both shifted and non-shifted windows. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__( + self, + dim: int, + window_size: Tuple[int, int], + num_heads: int, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + pretrained_window_size: Tuple[int, int] = (0, 0), + ) -> None: + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32) + relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w)) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / math.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(ndgrid(coords_h, coords_w)) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index, persistent=False) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.register_buffer('k_bias', torch.zeros(dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x:
torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + num_win = mask.shape[0] + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerV2Block(nn.Module): + """ Swin Transformer Block. + """ + + def __init__( + self, + dim: int, + input_resolution: _int_or_tuple_2_t, + num_heads: int, + window_size: _int_or_tuple_2_t = 7, + shift_size: _int_or_tuple_2_t = 0, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + pretrained_window_size: _int_or_tuple_2_t = 0, + ) -> None: + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + num_heads: Number of attention heads. + window_size: Window size. + shift_size: Shift size for SW-MSA. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + proj_drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + pretrained_window_size: Window size in pretraining. + """ + super().__init__() + self.dim = dim + self.input_resolution = to_2tuple(input_resolution) + self.num_heads = num_heads + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] + self.mlp_ratio = mlp_ratio + + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + ) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if any(self.shift_size): + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): + for w in ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask, persistent=False) + + def _calc_window_shift(self, + target_window_size: _int_or_tuple_2_t, + target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) + + def _attn(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, C = x.shape + + # cyclic shift + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution) # B H' W' C + + # reverse cyclic shift + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) + else: + x = shifted_x + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, C = x.shape + x = x + self.drop_path1(self.norm1(self._attn(x))) + x = x.reshape(B, -1, C) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + x = x.reshape(B, H, W, C) + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer. + """ + + def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None: + """ + Args: + dim (int): Number of input channels. + out_dim (int): Number of output channels (or 2 * dim if None) + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + super().__init__() + self.dim = dim + self.out_dim = out_dim or 2 * dim + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) + self.norm = norm_layer(self.out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, C = x.shape + _assert(H % 2 == 0, f"x height ({H}) is not even.") + _assert(W % 2 == 0, f"x width ({W}) is not even.") + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) + x = self.reduction(x) + x = self.norm(x) + return x + + +class SwinTransformerV2Stage(nn.Module): + """ A Swin Transformer V2 Stage. + """ + + def __init__( + self, + dim: int, + out_dim: int, + input_resolution: _int_or_tuple_2_t, + depth: int, + num_heads: int, + window_size: _int_or_tuple_2_t, + downsample: bool = False, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + pretrained_window_size: _int_or_tuple_2_t = 0, + output_nchw: bool = False, + ) -> None: + """ + Args: + dim: Number of input channels. + out_dim: Number of output channels. + input_resolution: Input resolution. + depth: Number of blocks. + num_heads: Number of attention heads. + window_size: Local window size. + downsample: Use downsample layer at start of the block. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + proj_drop: Projection dropout rate + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. + pretrained_window_size: Local window size in pretraining. + output_nchw: Output tensors on NCHW format instead of NHWC. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution + self.depth = depth + self.output_nchw = output_nchw + self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for w in window_size]) + + # patch merging / downsample layer + if downsample: + self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer) + else: + assert dim == out_dim + self.downsample = nn.Identity() + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerV2Block( + dim=out_dim, + input_resolution=self.output_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + ) + for i in range(depth)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.downsample(x) + + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + return x + + def _init_respostnorm(self) -> None: + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class SwinTransformerV2(nn.Module): + """ Swin Transformer V2 + + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + """ + + def __init__( + self, + img_size: _int_or_tuple_2_t = 224, + patch_size: int = 4, + in_chans: 
int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + norm_layer: Callable = nn.LayerNorm, + pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0), + **kwargs, + ): + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Pooling type for the classification head ('avg' or ''). + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer stage (layer). + num_heads: Number of attention heads in different layers. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop_rate: Head dropout rate. + proj_drop_rate: Projection dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + norm_layer: Normalization layer. + pretrained_window_sizes: Pretrained window sizes of each layer. + """ + super().__init__() + + self.num_classes = num_classes + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.output_fmt = 'NHWC' + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.feature_info = [] + + if not isinstance(embed_dim, (tuple, list)): + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim[0], + norm_layer=norm_layer, + output_fmt='NHWC', + ) + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + layers = [] + in_dim = embed_dim[0] + scale = 1 + for i in range(self.num_layers): + out_dim = embed_dim[i] + layers += [SwinTransformerV2Stage( + dim=in_dim, + out_dim=out_dim, + input_resolution=( + self.patch_embed.grid_size[0] // scale, + self.patch_embed.grid_size[1] // scale), + depth=depths[i], + downsample=i > 0, + num_heads=num_heads[i], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_sizes[i], + )] + in_dim = out_dim + if i > 0: + scale *= 2 + self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] + + self.layers = nn.Sequential(*layers) + self.norm = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + input_fmt=self.output_fmt, + ) + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + nod = set() + for n, m in self.named_modules(): + if any([kw in n for kw in ("cpb_mlp", "logit_scale")]): + nod.add(n) + return nod + + @torch.jit.ignore + def
group_matcher(self, coarse=False): + return dict( + stem=r'^absolute_pos_embed|patch_embed', # stem and embed + blocks=r'^layers\.(\d+)' if coarse else [ + (r'^layers\.(\d+).downsample', (0,)), + (r'^layers\.(\d+)\.\w+\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.layers(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + native_checkpoint = 'head.fc.weight' in state_dict + out_dict = {} + import re + for k, v in state_dict.items(): + if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]): + continue # skip buffers that should not be persistent + + if 'patch_embed.proj.weight' in k: + _, _, H, W = model.patch_embed.proj.weight.shape + if v.shape[-2] != H or v.shape[-1] != W: + v = resample_patch_embed( + v, + (H, W), + interpolation='bicubic', + antialias=True, + verbose=True, + ) + + if not native_checkpoint: + # skip layer remapping for updated checkpoints + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + out_dict[k] = v + + return out_dict + + +def _create_swin_transformer_v2(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + SwinTransformerV2, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'license': 'mit', **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth', + ), + 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth', + ), + 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + 
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
+    ),
+
+    'swinv2_tiny_window8_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
+    ),
+    'swinv2_tiny_window16_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
+    ),
+    'swinv2_small_window8_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
+    ),
+    'swinv2_small_window16_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
+    ),
+    'swinv2_base_window8_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
+    ),
+    'swinv2_base_window16_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
+    ),
+
+    'swinv2_base_window12_192.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
+    ),
+    'swinv2_large_window12_192.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
+    ),
+})
+
+
+@register_model
+def swinv2_tiny_window16_256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-T @ 256x256, window 16x16."""
+    model_args = dict(window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
+    return _create_swin_transformer_v2(
+        'swinv2_tiny_window16_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_tiny_window8_256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-T @ 256x256, window 8x8."""
+    model_args = dict(window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
+    return _create_swin_transformer_v2(
+        'swinv2_tiny_window8_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_small_window16_256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-S @ 256x256, window 16x16."""
+    model_args = dict(window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
+    return _create_swin_transformer_v2(
+        'swinv2_small_window16_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_small_window8_256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-S @ 256x256, window 8x8."""
+    model_args = dict(window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
+    return _create_swin_transformer_v2(
+        'swinv2_small_window8_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_base_window16_256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-B @ 256x256, window 16x16."""
+    model_args = dict(window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
+    return _create_swin_transformer_v2(
+        'swinv2_base_window16_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_base_window8_256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-B @ 256x256, window 8x8."""
+    model_args = dict(window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
+    return _create_swin_transformer_v2(
+        'swinv2_base_window8_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_base_window12_192(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-B @ 192x192, window 12x12."""
+    model_args = dict(window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
+    return _create_swin_transformer_v2(
+        'swinv2_base_window12_192', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_base_window12to16_192to256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-B @ 256x256, window 16x16, fine-tuned from a 192x192 / window 12x12 model."""
+    model_args = dict(
+        window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
+        pretrained_window_sizes=(12, 12, 12, 6))
+    return _create_swin_transformer_v2(
+        'swinv2_base_window12to16_192to256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_base_window12to24_192to384(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-B @ 384x384, window 24x24, fine-tuned from a 192x192 / window 12x12 model."""
+    model_args = dict(
+        window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
+        pretrained_window_sizes=(12, 12, 12, 6))
+    return _create_swin_transformer_v2(
+        'swinv2_base_window12to24_192to384', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_large_window12_192(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-L @ 192x192, window 12x12."""
+    model_args = dict(window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48))
+    return _create_swin_transformer_v2(
+        'swinv2_large_window12_192', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_large_window12to16_192to256(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-L @ 256x256, window 16x16, fine-tuned from a 192x192 / window 12x12 model."""
+    model_args = dict(
+        window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
+        pretrained_window_sizes=(12, 12, 12, 6))
+    return _create_swin_transformer_v2(
+        'swinv2_large_window12to16_192to256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def swinv2_large_window12to24_192to384(pretrained=False, **kwargs) -> SwinTransformerV2:
+    """SwinV2-L @ 384x384, window 24x24, fine-tuned from a 192x192 / window 12x12 model."""
+    model_args = dict(
+        window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
+        pretrained_window_sizes=(12, 12, 12, 6))
+    return _create_swin_transformer_v2(
+        'swinv2_large_window12to24_192to384', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+register_model_deprecations(__name__, {
+    'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
+    'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
+    'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
+    'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
+    'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
+    'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
+})
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tiny_vit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tiny_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b2964810e5490db149c0fcb6393baaeb933211
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tiny_vit.py
@@ -0,0 +1,716 @@
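+# Usage sketch (a minimal example; assumes timm and torch are importable and the
+# entrypoints registered at the bottom of this file are available):
+#
+#   import timm, torch
+#   model = timm.create_model('tiny_vit_5m_224', pretrained=False)
+#   logits = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)
+#
+# timm.create_model resolves the name through the @register_model entrypoints below.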
+""" TinyViT + +Paper: `TinyViT: Fast Pretraining Distillation for Small Vision Transformers` + - https://arxiv.org/abs/2207.10666 + +Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyViT +""" + +__all__ = ['TinyVit'] + +import math +import itertools +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ + trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class ConvNorm(torch.nn.Sequential): + def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) + self.bn = nn.BatchNorm2d(out_chs) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self.conv, self.bn + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1) * self.conv.groups, w.size(0), w.shape[2:], + stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchEmbed(nn.Module): + def __init__(self, in_chs, out_chs, act_layer): + super().__init__() + self.stride = 4 + self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1) + self.act = act_layer() + self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.act(x) + x = self.conv2(x) + return x + + +class MBConv(nn.Module): + def __init__(self, in_chs, out_chs, expand_ratio, act_layer, drop_path): + super().__init__() + mid_chs = int(in_chs * expand_ratio) + self.conv1 = ConvNorm(in_chs, mid_chs, ks=1) + self.act1 = act_layer() + self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs) + self.act2 = act_layer() + self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0) + self.act3 = act_layer() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop_path(x) + x += shortcut + x = self.act3(x) + return x + + +class PatchMerging(nn.Module): + def __init__(self, dim, out_dim, act_layer): + super().__init__() + self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0) + self.act1 = act_layer() + self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim) + self.act2 = act_layer() + self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + return x + + +class ConvLayer(nn.Module): + def __init__( + self, + dim, + depth, + act_layer, + drop_path=0., + conv_expand_ratio=4., + ): + super().__init__() + self.dim = dim + self.depth = depth + self.blocks = nn.Sequential(*[ + MBConv( + dim, dim, conv_expand_ratio, act_layer, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ]) + + def forward(self, x): + x = self.blocks(x) + return x + + +class NormMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + drop=0., + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = norm_layer(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Attention(torch.nn.Module): + fused_attn: torch.jit.Final[bool] + attention_bias_cache: Dict[str, torch.Tensor] + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.val_dim = int(attn_ratio * key_dim) + self.out_dim = self.val_dim * num_heads + self.attn_ratio = attn_ratio + self.resolution = resolution + self.fused_attn = use_fused_attn() + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim)) + self.proj = nn.Linear(self.out_dim, dim) + + points = list(itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.attention_bias_cache = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if torch.jit.is_tracing() or self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + 
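+                # eval-time fast path: gather the relative-position biases once per device and
+                # cache the result; train() (above) clears this cache so training stays correct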
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): + attn_bias = self.get_attention_biases(x.device) + B, N, _ = x.shape + # Normalization + x = self.norm(x) + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn + attn_bias + attn = attn.softmax(dim=-1) + x = attn @ v + x = x.transpose(1, 2).reshape(B, N, self.out_dim) + x = self.proj(x) + return x + + +class TinyVitBlock(nn.Module): + """ TinyViT Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + act_layer: the activation function. Default: nn.GELU + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + mlp_ratio=4., + drop=0., + drop_path=0., + local_conv_size=3, + act_layer=nn.GELU + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + + self.mlp = NormMlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + pad = local_conv_size // 2 + self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + B, H, W, C = x.shape + L = H * W + + shortcut = x + if H == self.window_size and W == self.window_size: + x = x.reshape(B, L, C) + x = self.attn(x) + x = x.view(B, H, W, C) + else: + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_r = (self.window_size - W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + # window partition + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C + ) + + x = self.attn(x) + + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + x = shortcut + self.drop_path1(x) + + x = x.permute(0, 3, 1, 2) + x = self.local_conv(x) + x = x.reshape(B, C, L).transpose(1, 2) + + x = x + self.drop_path2(self.mlp(x)) + return x.view(B, H, W, C) + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +register_notrace_module(TinyVitBlock) + + +class TinyVitStage(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + out_dim: the output dimension of the layer + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + act_layer: the activation function. 
Default: nn.GELU + """ + + def __init__( + self, + dim, + out_dim, + depth, + num_heads, + window_size, + mlp_ratio=4., + drop=0., + drop_path=0., + downsample=None, + local_conv_size=3, + act_layer=nn.GELU, + ): + + super().__init__() + self.depth = depth + self.out_dim = out_dim + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + dim=dim, + out_dim=out_dim, + act_layer=act_layer, + ) + else: + self.downsample = nn.Identity() + assert dim == out_dim + + # build blocks + self.blocks = nn.Sequential(*[ + TinyVitBlock( + dim=out_dim, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + act_layer=act_layer, + ) + for i in range(depth)]) + + def forward(self, x): + x = self.downsample(x) + x = x.permute(0, 2, 3, 1) # BCHW -> BHWC + x = self.blocks(x) + x = x.permute(0, 3, 1, 2) # BHWC -> BCHW + return x + + def extra_repr(self) -> str: + return f"dim={self.out_dim}, depth={self.depth}" + + +class TinyVit(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dims=(96, 192, 384, 768), + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_sizes=(7, 7, 14, 7), + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + act_layer=nn.GELU, + ): + super().__init__() + + self.num_classes = num_classes + self.depths = depths + self.num_stages = len(depths) + self.mlp_ratio = mlp_ratio + self.grad_checkpointing = use_checkpoint + + self.patch_embed = PatchEmbed( + in_chs=in_chans, + out_chs=embed_dims[0], + act_layer=act_layer, + ) + + # stochastic depth rate rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + # build stages + self.stages = nn.Sequential() + stride = self.patch_embed.stride + prev_dim = embed_dims[0] + self.feature_info = [] + for stage_idx in range(self.num_stages): + if stage_idx == 0: + stage = ConvLayer( + dim=prev_dim, + depth=depths[stage_idx], + act_layer=act_layer, + drop_path=dpr[:depths[stage_idx]], + conv_expand_ratio=mbconv_expand_ratio, + ) + else: + out_dim = embed_dims[stage_idx] + drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])] + stage = TinyVitStage( + dim=embed_dims[stage_idx - 1], + out_dim=out_dim, + depth=depths[stage_idx], + num_heads=num_heads[stage_idx], + window_size=window_sizes[stage_idx], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + drop_path=drop_path_rate, + downsample=PatchMerging, + act_layer=act_layer, + ) + prev_dim = out_dim + stride *= 2 + self.stages.append(stage) + self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')] + + # Classifier head + self.num_features = embed_dims[-1] + + norm_layer_cf = partial(LayerNorm2d, eps=1e-5) + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + norm_layer=norm_layer_cf, + ) + + # init weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + @torch.jit.ignore + def 
group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.patch_embed(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x): + x = self.head(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict.keys(): + state_dict = state_dict['model'] + target_sd = model.state_dict() + out_dict = {} + for k, v in state_dict.items(): + if k.endswith('attention_bias_idxs'): + continue + if 'attention_biases' in k: + # TODO: whether move this func into model for dynamic input resolution? (high risk) + v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T + out_dict[k] = v + return out_dict + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv1.conv', + 'classifier': 'head.fc', + 'pool_size': (7, 7), + 'input_size': (3, 224, 224), + 'crop_pct': 0.95, + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'tiny_vit_5m_224.dist_in22k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth', + num_classes=21841 + ), + 'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth' + ), + 'tiny_vit_5m_224.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth' + ), + 'tiny_vit_11m_224.dist_in22k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth', + num_classes=21841 + ), + 'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth' + ), + 'tiny_vit_11m_224.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth' + ), + 'tiny_vit_21m_224.dist_in22k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth', + num_classes=21841 + ), + 'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth' + ), + 'tiny_vit_21m_224.in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth' + ), + 'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # 
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash', + ), +}) + + +def _create_tiny_vit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + TinyVit, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs + ) + return model + + +@register_model +def tiny_vit_5m_224(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=0.0, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_5m_224', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_11m_224(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=0.1, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_11m_224', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_21m_224(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=0.2, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_21m_224', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_21m_384(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=0.1, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_21m_384', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_21m_512(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=0.1, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_21m_512', pretrained, **model_kwargs) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/twins.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..3cd25fb43389b389d59e66e68ca3fc0586d3c2d6 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/twins.py @@ -0,0 +1,505 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf + +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below + +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from functools import partial +from typing import Tuple + +import torch +import torch.nn 
as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._registry import register_model, generate_default_cfgs
+from .vision_transformer import Attention
+
+__all__ = ['Twins']  # model_registry will add each entrypoint fn to this
+
+Size_ = Tuple[int, int]
+
+
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
+class LocallyGroupedAttn(nn.Module):
+    """ LSA: self attention within a group
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1):
+        assert ws != 1
+        super(LocallyGroupedAttn, self).__init__()
+        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+        self.dim = dim
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.ws = ws
+
+    def forward(self, x, size: Size_):
+        # There are two implementations of this function: zero padding (used here) and masking
+        # (kept, commented out, below). We observe no obvious accuracy difference between them;
+        # the padding version is simpler, while the masking implementation is arguably more precise.
+        B, N, C = x.shape
+        H, W = size
+        x = x.view(B, H, W, C)
+        pad_l = pad_t = 0
+        pad_r = (self.ws - W % self.ws) % self.ws
+        pad_b = (self.ws - H % self.ws) % self.ws
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+        _h, _w = Hp // self.ws, Wp // self.ws
+        x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)
+        qkv = self.qkv(x).reshape(
+            B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
+        q, k, v = qkv.unbind(0)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
+        x = x.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+        x = x.reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    # def forward_mask(self, x, size: Size_):
+    #     B, N, C = x.shape
+    #     H, W = size
+    #     x = x.view(B, H, W, C)
+    #     pad_l = pad_t = 0
+    #     pad_r = (self.ws - W % self.ws) % self.ws
+    #     pad_b = (self.ws - H % self.ws) % self.ws
+    #     x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+    #     _, Hp, Wp, _ = x.shape
+    #     _h, _w = Hp // self.ws, Wp // self.ws
+    #     mask = torch.zeros((1, Hp, Wp), device=x.device)
+    #     mask[:, -pad_b:, :].fill_(1)
+    #     mask[:, :, -pad_r:].fill_(1)
+    #
+    #     x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)  # B, _h, _w, ws, ws, C
+    #     mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h * _w, self.ws * self.ws)
+    #     attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)  # 1, _h*_w, ws*ws, ws*ws
+    #     attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0))
+    #     qkv = self.qkv(x).reshape(
+    #         B, _h * _w, self.ws * self.ws, 3, 
self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + # # n_h, B, _w*_h, nhead, ws*ws, dim + # q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head + # attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws + # attn = attn + attn_mask.unsqueeze(2) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head + # attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + # x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + # if pad_r > 0 or pad_b > 0: + # x = x[:, :H, :W, :].contiguous() + # x = x.reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + # return x + + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + fused_attn: torch.jit.Final[bool] + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + if ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, proj_drop) + elif ws == 1: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x, size: Size_): + x = x + self.drop_path1(self.attn(self.norm1(x), size)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), + ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """ Twins Vision Transfomer (Revisiting Spatial Attention) + + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), + mlp_ratios=(4, 4, 4, 4), + depths=(3, 4, 6, 3), + sr_ratios=(8, 4, 2, 1), + wss=None, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + block_cls=Block, + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + self.grad_checkpointing = False + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.pos_drops.append(nn.Dropout(p=pos_drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], + num_heads=num_heads[k], + mlp_ratio=mlp_ratios[k], + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])], + ) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = 
nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # init weights + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embeds.0', # stem and embed + blocks=[ + (r'^(?:blocks|patch_embeds|pos_block)\.(\d+)', None), + ('^norm', (99999,)) + ] if coarse else [ + (r'^blocks\.(\d+)\.(\d+)', None), + (r'^(?:patch_embeds|pos_block)\.(\d+)', (0,)), + (r'^norm', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=1) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_twins(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg(Twins, variant, pretrained, **kwargs) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'twins_pcpvt_small.in1k': _cfg(hf_hub_id='timm/'), + 'twins_pcpvt_base.in1k': _cfg(hf_hub_id='timm/'), + 'twins_pcpvt_large.in1k': _cfg(hf_hub_id='timm/'), + 'twins_svt_small.in1k': _cfg(hf_hub_id='timm/'), + 'twins_svt_base.in1k': _cfg(hf_hub_id='timm/'), + 'twins_svt_large.in1k': _cfg(hf_hub_id='timm/'), +}) + + +@register_model +def 
twins_pcpvt_small(pretrained=False, **kwargs) -> Twins: + model_args = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_pcpvt_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def twins_pcpvt_base(pretrained=False, **kwargs) -> Twins: + model_args = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_pcpvt_base', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def twins_pcpvt_large(pretrained=False, **kwargs) -> Twins: + model_args = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_pcpvt_large', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def twins_svt_small(pretrained=False, **kwargs) -> Twins: + model_args = dict( + patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_svt_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def twins_svt_base(pretrained=False, **kwargs) -> Twins: + model_args = dict( + patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_svt_base', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def twins_svt_large(pretrained=False, **kwargs) -> Twins: + model_args = dict( + patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_svt_large', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/visformer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/visformer.py new file mode 100644 index 0000000000000000000000000000000000000000..953fc64d5e2440de642df9677db4a76c00c3023d --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/visformer.py @@ -0,0 +1,549 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + +__all__ = ['Visformer'] + + +class SpatialMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0., + group=8, + spatial_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + drop_probs = to_2tuple(drop) + + self.in_features = 
in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + self.drop3 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop1(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop3(x) + return x + + +class Attention(nn.Module): + fused_attn: torch.jit.Final[bool] + + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn(experimental=True) + + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) + q, k, v = x.unbind(0) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q.contiguous(), k.contiguous(), v.contiguous(), + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + head_dim_ratio=1., + mlp_ratio=4., + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=LayerNorm2d, + group=8, + attn_disabled=False, + spatial_conv=False, + ): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + head_dim_ratio=head_dim_ratio, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + + self.norm2 = norm_layer(dim) + self.mlp = SpatialMlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + group=group, + spatial_conv=spatial_conv, + ) + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + init_channels=32, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4., + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=LayerNorm2d, + attn_stage='111', + use_pos_embed=True, + spatial_conv='111', + vit_stem=False, + group=8, + global_pool='avg', + conv_init=False, + embed_norm=None, + ): + super().__init__() + img_size = to_2tuple(img_size) + self.num_classes = num_classes + self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.use_pos_embed = use_pos_embed + self.grad_checkpointing = False + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=embed_norm, + flatten=False, + ) + img_size = [x // patch_size for x in img_size] + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, + patch_size=patch_size // 2, + in_chans=in_chans, + embed_dim=embed_dim // 2, + norm_layer=embed_norm, + flatten=False, + ) + img_size = [x // (patch_size // 2) for x in img_size] + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size = [x // 2 for x in img_size] + self.patch_embed1 = PatchEmbed( + img_size=img_size, + patch_size=patch_size // 4, + in_chans=self.init_channels, + embed_dim=embed_dim // 2, + norm_layer=embed_norm, + flatten=False, + ) + img_size = [x // (patch_size // 4) for x in img_size] + + if self.use_pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + else: + self.pos_embed1 = None + + self.stage1 = nn.Sequential(*[ + Block( + dim=embed_dim//2, + num_heads=num_heads, + head_dim_ratio=0.5, + mlp_ratio=mlp_ratio, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + group=group, + attn_disabled=(attn_stage[0] == '0'), + spatial_conv=(spatial_conv[0] == '1'), + ) + for i in range(self.stage_num1) + ]) + + # stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, + patch_size=patch_size // 8, + 
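+                # with the default patch_size=16 this is a stride-2 embed, halving resolution for stage 2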
in_chans=embed_dim // 2, + embed_dim=embed_dim, + norm_layer=embed_norm, + flatten=False, + ) + img_size = [x // (patch_size // 8) for x in img_size] + if self.use_pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + else: + self.pos_embed2 = None + else: + self.patch_embed2 = None + self.stage2 = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=num_heads, + head_dim_ratio=1.0, + mlp_ratio=mlp_ratio, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + group=group, + attn_disabled=(attn_stage[1] == '0'), + spatial_conv=(spatial_conv[1] == '1'), + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, + patch_size=patch_size // 8, + in_chans=embed_dim, + embed_dim=embed_dim * 2, + norm_layer=embed_norm, + flatten=False, + ) + img_size = [x // (patch_size // 8) for x in img_size] + if self.use_pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) + else: + self.pos_embed3 = None + else: + self.patch_embed3 = None + self.stage3 = nn.Sequential(*[ + Block( + dim=embed_dim * 2, + num_heads=num_heads, + head_dim_ratio=1.0, + mlp_ratio=mlp_ratio, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + group=group, + attn_disabled=(attn_stage[2] == '0'), + spatial_conv=(spatial_conv[2] == '1'), + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + self.num_features = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(self.num_features) + + # head + global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool = global_pool + self.head_drop = nn.Dropout(drop_rate) + self.head = head + + # weights init + if self.use_pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) 
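+
+    # Shape walkthrough (a sketch for the visformer_small config defined below:
+    # patch_size=16, init_channels=32, embed_dim=384, depth=(7, 4, 4), input 3x224x224):
+    #   stem: 7x7 conv, stride 2          -> (B,  32, 112, 112)
+    #   patch_embed1: stride 16 // 4 = 4  -> (B, 192,  28,  28), then stage1 (7 blocks)
+    #   patch_embed2: stride 16 // 8 = 2  -> (B, 384,  14,  14), then stage2 (4 blocks)
+    #   patch_embed3: stride 16 // 8 = 2  -> (B, 768,   7,   7), then stage3 (4 blocks)
+    #   norm, global pool, head           -> (B, num_classes)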
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^patch_embed1|pos_embed1|stem',  # stem and embed
+            blocks=[
+                (r'^stage(\d+)\.(\d+)', None),
+                (r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
+                (r'^norm', (99999,))
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+    def forward_features(self, x):
+        if self.stem is not None:
+            x = self.stem(x)
+
+        # stage 1
+        x = self.patch_embed1(x)
+        if self.pos_embed1 is not None:
+            x = self.pos_drop(x + self.pos_embed1)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stage1, x)
+        else:
+            x = self.stage1(x)
+
+        # stage 2
+        if self.patch_embed2 is not None:
+            x = self.patch_embed2(x)
+            if self.pos_embed2 is not None:
+                x = self.pos_drop(x + self.pos_embed2)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stage2, x)
+        else:
+            x = self.stage2(x)
+
+        # stage3
+        if self.patch_embed3 is not None:
+            x = self.patch_embed3(x)
+            if self.pos_embed3 is not None:
+                x = self.pos_drop(x + self.pos_embed3)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stage3, x)
+        else:
+            x = self.stage3(x)
+
+        x = self.norm(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.head_drop(x)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    model = build_model_with_cfg(Visformer, variant, pretrained, **kwargs)
+    return model
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.0', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'visformer_tiny.in1k': _cfg(hf_hub_id='timm/'),
+    'visformer_small.in1k': _cfg(hf_hub_id='timm/'),
+})
+
+
+@register_model
+def visformer_tiny(pretrained=False, **kwargs) -> Visformer:
+    model_cfg = dict(
+        init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
+        attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
+        embed_norm=nn.BatchNorm2d)
+    model = _create_visformer('visformer_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
+@register_model
+def visformer_small(pretrained=False, **kwargs) -> Visformer:
+    model_cfg = dict(
+        init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
+        attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
+        embed_norm=nn.BatchNorm2d)
+    model = _create_visformer('visformer_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
+# @register_model
+# def visformer_net1(pretrained=False, **kwargs):
+#     model = Visformer(
init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model + + + + diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vision_transformer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..70f91d588b78629bda6bd35313c8acc1624b0b04 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vision_transformer.py @@ -0,0 +1,2696 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implementation of Vision Transformers as described in: + +`An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +`FlexiViT: One Model for All Patch Sizes` + - https://arxiv.org/abs/2212.08013 + +The official jax code is released and available at + * https://github.com/google-research/vision_transformer + * https://github.com/google-research/big_vision + +Acknowledgments: + * The paper authors for releasing code and weights, thanks!
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch + * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT + * Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" +import logging +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.jit import Final + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ + OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ + trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ + get_act_layer, get_norm_layer, LayerType +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + + +_logger = logging.getLogger(__name__) + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + 
act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.init_values = init_values + + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.init_weights() + + def init_weights(self) -> None: + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class ParallelScalingBlock(nn.Module): + """ Parallel ViT block (MLP & Attention in parallel) + Based on: + 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 + """ + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: Optional[nn.Module] = None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + mlp_hidden_dim = int(mlp_ratio * dim) + in_proj_out_dim = mlp_hidden_dim + 3 * dim + + self.in_norm = norm_layer(dim) + self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias) + self.in_split = [mlp_hidden_dim] + [dim] * 3 + if qkv_bias: + self.register_buffer('qkv_bias', None) + self.register_parameter('mlp_bias', None) + else: + self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False) + self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim)) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.attn_out_proj = nn.Linear(dim, dim) + + self.mlp_drop = nn.Dropout(proj_drop) + self.mlp_act = act_layer() + self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim) + + self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + + # Combined MLP fc1 & qkv projections + y = self.in_norm(x) + if self.mlp_bias is not None: + # Concat constant zero-bias for qkv w/ trainable mlp_bias. 
+ # Appears faster than adding to x_mlp separately + y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias))) + else: + y = self.in_proj(y) + x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1) + + # Dot product attention w/ qk norm + q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + if self.fused_attn: + x_attn = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x_attn = attn @ v + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) + x_attn = self.attn_out_proj(x_attn) + + # MLP activation, dropout, fc2 + x_mlp = self.mlp_act(x_mlp) + x_mlp = self.mlp_drop(x_mlp) + x_mlp = self.mlp_out_proj(x_mlp) + + # Add residual w/ drop path & layer scale applied + y = self.drop_path(self.ls(x_attn + x_mlp)) + x = x + y + return x + + +class ParallelThingsBlock(nn.Module): + """ Parallel ViT block (N parallel attention followed by N parallel MLP) + Based on: + `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + def __init__( + self, + dim: int, + num_heads: int, + num_parallel: int = 2, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + init_values: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('attn', Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + )), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + self.ffns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('mlp', mlp_layer( + dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + )), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) + ]))) + + def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal['', 'avg', 'token', 'map'] = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', + fix_init: bool = False, + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre-head norm after pool (instead of before); if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer.
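+            mlp_layer: MLP block layer.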
+ """ + super().__init__() + assert global_pool in ('', 'avg', 'token', 'map') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + act_layer = get_act_layer(act_layer) or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == 'map': + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
+ trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m: nn.Module) -> None: + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """ Intermediate layer accessor (NOTE: This is a WIP experiment).
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int; if n is a sequence, select blocks by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens:] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == 'avg': + x = x[:, self.num_prefix_tokens:].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def init_weights_vit_timm(module: nn.Module, name: str = '') -> None: + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None: + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = '') -> None: + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6.
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +def resize_pos_embed( + posemb: torch.Tensor, + posemb_new: torch.Tensor, + num_prefix_tokens: int = 1, + gs_new: Tuple[int, int] = (), + interpolation: str = 'bicubic', + antialias: bool = False, +) -> torch.Tensor: + """ Rescale the grid of position embeddings when loading from state_dict. + + *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed + + Adapted from: + https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + """ + ntok_new = posemb_new.shape[1] + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens + else: + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).') + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) + return posemb + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '') -> None: + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + interpolation = 'bilinear' + antialias = False + big_vision = False + if not prefix: + if 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + elif 'params/embedding/kernel' in w: + prefix = 'params/' + big_vision = True + elif 'params/img/embedding/kernel' in w: + prefix = 'params/img/' + big_vision = True + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 
1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: + embed_conv_w = resample_patch_embed( + embed_conv_w, + model.patch_embed.proj.weight.shape[-2:], + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + if model.cls_token is not None: + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + if big_vision: + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) + else: + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + old_shape = pos_embed_w.shape + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if (isinstance(model.head, nn.Linear) and + f'{prefix}head/bias' in w and + model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]): + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + if model.attn_pool is not None: + block_prefix = f'{prefix}MAPHead_0/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) + model.attn_pool.kv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) + model.attn_pool.kv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) + model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) + model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) + model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + for r in range(2): + getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) + 
getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) + + mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) + + +def _convert_openai_clip( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + prefix: str = 'visual.', +) -> Dict[str, torch.Tensor]: + out_dict = {} + swaps = [ + ('conv1', 'patch_embed.proj'), + ('positional_embedding', 'pos_embed'), + ('transformer.resblocks.', 'blocks.'), + ('ln_pre', 'norm_pre'), + ('ln_post', 'norm'), + ('ln_', 'norm'), + ('in_proj_', 'qkv.'), + ('out_proj', 'proj'), + ('mlp.c_fc', 'mlp.fc1'), + ('mlp.c_proj', 'mlp.fc2'), + ] + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = k.replace(prefix, '') + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k == 'proj': + k = 'head.weight' + v = v.transpose(0, 1) + out_dict['head.bias'] = torch.zeros(v.shape[0]) + elif k == 'class_embedding': + k = 'cls_token' + v = v.unsqueeze(0).unsqueeze(1) + elif k == 'pos_embed': + v = v.unsqueeze(0) + if v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, + model.pos_embed, + 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + out_dict[k] = v + return out_dict + + +def _convert_dinov2( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + import re + out_dict = {} + state_dict.pop("mask_token", None) + if 'register_tokens' in state_dict: + # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed) + out_dict['reg_token'] = state_dict.pop('register_tokens') + out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0] + out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:] + for k, v in state_dict.items(): + if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): + out_dict[k.replace("w12", "fc1")] = v + continue + elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): + out_dict[k.replace("w3", "fc2")] = v + continue + out_dict[k] = v + return out_dict + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + adapt_layer_scale: bool = False, + interpolation: str = 'bicubic', + antialias: bool = True, +) -> 
Dict[str, torch.Tensor]: + """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re + out_dict = {} + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + prefix = '' + + if 'visual.class_embedding' in state_dict: + return _convert_openai_clip(state_dict, model) + elif 'module.visual.class_embedding' in state_dict: + return _convert_openai_clip(state_dict, model, prefix='module.visual.') + + if "mask_token" in state_dict: + state_dict = _convert_dinov2(state_dict, model) + + if "encoder" in state_dict: + state_dict = state_dict['encoder'] + prefix = 'module.' + + if 'visual.trunk.pos_embed' in state_dict: + # convert an OpenCLIP model with timm vision encoder + # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + prefix = 'visual.trunk.' + + if prefix: + # filter on & remove prefix string from keys + state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif adapt_layer_scale and 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': 0.9, + 'interpolation': 'bicubic', + 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, + 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + **kwargs, + } + +default_cfgs = { + + # re-finetuned augreg 21k FT on in1k weights + 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), + 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + + # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k + 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 
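+    # NOTE: cfgs with custom_load=True point at the original JAX .npz checkpoints,
+    # which are loaded through the numpy-based `_load_weights()` above rather than
+    # a regular torch state_dict.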
'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + + # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k + 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + hf_hub_id='timm/'), + 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 
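+    # NOTE: the 384-res fine-tuned weights override input_size to (3, 384, 384) and
+    # set crop_pct=1.0, i.e. evaluation uses the full resized image with no extra
+    # center-crop margin.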
'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + + # How to train your ViT (augreg) weights trained on in1k only + 'vit_small_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_large_patch14_224.untrained': _cfg(url=''), + 'vit_huge_patch14_224.untrained': _cfg(url=''), + 'vit_giant_patch14_224.untrained': _cfg(url=''), + 'vit_gigantic_patch14_224.untrained': _cfg(url=''), + + # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid + 'vit_base_patch32_224.orig_in21k': _cfg( + #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_224.orig_in21k': _cfg( + #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_large_patch32_224.orig_in21k': _cfg( + #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_large_patch16_224.orig_in21k': _cfg( + #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_huge_patch14_224.orig_in21k': _cfg( + hf_hub_id='timm/', + num_classes=0), + + # How to train your ViT (augreg) weights, pretrained on in21k + 'vit_tiny_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_small_patch32_224.augreg_in21k': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_small_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch8_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_large_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_224.sam_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True, + hf_hub_id='timm/'), + 'vit_base_patch16_224.sam_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True, + hf_hub_id='timm/'), + + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only) + 'vit_small_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_base_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_large_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + 
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_giant_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + + # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only) + 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', + hf_hub_id='timm/', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), + 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', + hf_hub_id='timm/', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), + + # Custom timm variants + 'vit_base_patch16_rpn_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth', + hf_hub_id='timm/'), + 'vit_medium_patch16_gap_240.sw_in12k': _cfg( + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), + 'vit_medium_patch16_gap_256.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), + 'vit_base_patch16_gap_224': _cfg(), + + # CLIP pretrained image tower and related fine-tuned weights + 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), + 
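+    # NOTE: most CLIP-derived weights keep OPENAI_CLIP_MEAN/STD normalization, but
+    # some of the large variants below use IMAGENET_INCEPTION_MEAN/STD; rely on the
+    # per-weight cfg rather than assuming a single normalization.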
'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + 
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + + 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( + #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + + 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + + 'vit_base_patch32_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_giant_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_gigantic_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), + + 'vit_base_patch32_clip_224.datacompxl': _cfg( + hf_hub_id='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_base_patch32_clip_256.datacompxl': _cfg( + hf_hub_id='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 256, 256), num_classes=512), + 'vit_base_patch16_clip_224.datacompxl': _cfg( + hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.datacompxl': _cfg( + 
hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + + 'vit_base_patch16_clip_224.dfn2b': _cfg( + hf_hub_id='apple/DFN2B-CLIP-ViT-B-16', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.dfn2b': _cfg( + hf_hub_id='apple/DFN2B-CLIP-ViT-L-14', + hf_hub_filename='open_clip_pytorch_model.bin', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.dfn5b': _cfg( + hf_hub_id='apple/DFN5B-CLIP-ViT-H-14', + hf_hub_filename='open_clip_pytorch_model.bin', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_huge_patch14_clip_378.dfn5b': _cfg( + hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + notes=('natively QuickGELU, use quickgelu model variant for original results',), + crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024), + + 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='facebook/metaclip-b32-fullcc2.5b', + hf_hub_filename='metaclip_b32_fullcc2.5b.bin', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='facebook/metaclip-b16-fullcc2.5b', + hf_hub_filename='metaclip_b16_fullcc2.5b.bin', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='facebook/metaclip-l14-fullcc2.5b', + hf_hub_filename='metaclip_l14_fullcc2.5b.bin', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='facebook/metaclip-h14-fullcc2.5b', + hf_hub_filename='metaclip_h14_fullcc2.5b.bin', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + + 'vit_base_patch32_clip_224.openai': _cfg( + hf_hub_id='timm/vit_base_patch32_clip_224.openai', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.openai': _cfg( + hf_hub_id='timm/vit_base_patch16_clip_224.openai', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_large_patch14_clip_224.openai': _cfg( + hf_hub_id='timm/vit_large_patch14_clip_224.openai', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 'vit_large_patch14_clip_336.openai': _cfg( + hf_hub_id='timm/vit_large_patch14_clip_336.openai', 
hf_hub_filename='open_clip_pytorch_model.bin', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), num_classes=768), + + # experimental (may be removed) + 'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), + 'vit_small_patch16_36x1_224.untrained': _cfg(url=''), + 'vit_small_patch16_18x2_224.untrained': _cfg(url=''), + 'vit_base_patch16_18x2_224.untrained': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + + 'flexivit_small.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.1000ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.300ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 
240), crop_pct=0.95, num_classes=21843), + + 'flexivit_large.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.patch16_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.patch30_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + + 'vit_base_patch16_xp_224.untrained': _cfg(url=''), + 'vit_large_patch14_xp_224.untrained': _cfg(url=''), + 'vit_huge_patch14_xp_224.untrained': _cfg(url=''), + + 'vit_base_patch16_224.mae': _cfg( + url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth', + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_large_patch16_224.mae': _cfg( + url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth', + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_huge_patch14_224.mae': _cfg( + url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth', + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + 'vit_huge_patch14_gap_224.in1k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_huge_patch14_gap_224.in22k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_huge_patch16_gap_448.in1k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + input_size=(3, 448, 448), crop_pct=1.0, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_giant_patch16_gap_224.in22k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + 'vit_base_patch16_siglip_224.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0), + 'vit_base_patch16_siglip_256.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP-256', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_384.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP-384', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 384, 384), + num_classes=0), + 'vit_base_patch16_siglip_512.webli': _cfg( + 
hf_hub_id='timm/ViT-B-16-SigLIP-512', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 512, 512), + num_classes=0), + 'vit_large_patch16_siglip_256.webli': _cfg( + hf_hub_id='timm/ViT-L-16-SigLIP-256', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 256, 256), + num_classes=0), + 'vit_large_patch16_siglip_384.webli': _cfg( + hf_hub_id='timm/ViT-L-16-SigLIP-384', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 384, 384), + num_classes=0), + 'vit_so400m_patch14_siglip_224.webli': _cfg( + hf_hub_id='timm/ViT-SO400M-14-SigLIP', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0), + 'vit_so400m_patch14_siglip_384.webli': _cfg( + hf_hub_id='timm/ViT-SO400M-14-SigLIP-384', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 384, 384), + num_classes=0), + + 'vit_medium_patch16_reg4_256': _cfg( + input_size=(3, 256, 256)), + 'vit_medium_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_base_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_map_256': _cfg( + input_size=(3, 256, 256)), +} + +_quick_gelu_cfgs = [ + 'vit_large_patch14_clip_224.dfn2b', + 'vit_huge_patch14_clip_224.dfn5b', + 'vit_huge_patch14_clip_378.dfn5b', + 'vit_base_patch32_clip_224.metaclip_2pt5b', + 'vit_base_patch16_clip_224.metaclip_2pt5b', + 'vit_large_patch14_clip_224.metaclip_2pt5b', + 'vit_huge_patch14_clip_224.metaclip_2pt5b', + 'vit_base_patch32_clip_224.openai', + 'vit_base_patch16_clip_224.openai', + 'vit_large_patch14_clip_224.openai', + 'vit_large_patch14_clip_336.openai', +] +default_cfgs.update({ + n.replace('_clip_', '_clip_quickgelu_'): default_cfgs[n] for n in _quick_gelu_cfgs +}) +default_cfgs = generate_default_cfgs(default_cfgs) + + +def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer: + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + if 'flexi' in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. + _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) + else: + _filter_fn = checkpoint_filter_fn + + # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? + strict = True + if 'siglip' in variant and kwargs.get('global_pool', None) != 'map': + strict = False + + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=_filter_fn, + pretrained_strict=strict, + **kwargs, + ) + + +@register_model +def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Tiny (ViT-Ti/16) + """ + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Tiny (ViT-Ti/16) @ 384x384. 
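+ + A minimal shape sketch (assumes the standard timm factory; 384/16 gives a 24x24 patch grid): + >>> import timm, torch + >>> m = timm.create_model('vit_tiny_patch16_384', pretrained=False) + >>> m.forward_features(torch.randn(1, 3, 384, 384)).shape # 576 patch tokens + 1 class token + torch.Size([1, 577, 192])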
+ """ + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/32) + """ + model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/16) + """ + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/16) + """ + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/8) + """ + model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
+ """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) + model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_args = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16) + model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_args = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) + model = _create_vision_transformer( + 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False) + model = _create_vision_transformer( + 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 + """ + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 + """ + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 + """ + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224 + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool + """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448 + """ + model_args = dict( + patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-g) model (ViT-g/16) w/ no class token, avg pool + """ + model_args = dict( + patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, 
**kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 224x224 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 256x256 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 384x384 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 448x448 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/16 CLIP image tower + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/16 CLIP image tower @ 384x384 + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower. 
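+ + A feature-extraction sketch (assumes timm's convention that num_classes=0 swaps the head + for identity and returns pooled embed_dim features): + >>> import timm, torch + >>> m = timm.create_model('vit_huge_patch14_clip_224', pretrained=False, num_classes=0) + >>> m(torch.randn(1, 3, 224, 224)).shape + torch.Size([1, 1280])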
+ """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 + """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 + """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. + """ + model_args = dict( + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. 
+ """ + model_args = dict( + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 224x224 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/16 CLIP image tower w/ QuickGELU act + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act + """ + from timm.layers import get_act_layer + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act. 
+ """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act + """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +# Experimental models below + +@register_model +def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/32+) + """ + model_args = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16+) + """ + model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) w/ residual post-norm + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + class_token=False, block_fn=ResPostBlock, global_pool='avg') + model = _create_vision_transformer( + 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5) + model = _create_vision_transformer( + 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_args = dict( + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock) + model = _create_vision_transformer( + 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 
+ Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock) + model = _create_vision_transformer( + 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer( + 'eva_large_patch14_196', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ FlexiViT-Small + """ + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True) + model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ FlexiViT-Base + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True) + model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ FlexiViT-Large + """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True) + model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base model (ViT-B/16) w/ parallel blocks and qk norm enabled. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled. 
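+ + These 'xp' variants pair RmsNorm with ParallelScalingBlock (parallel attention + MLP + branches plus query/key normalization, in the spirit of the ViT-22B report); a quick + per-head width check (editorial, not from the original source): + >>> 1280 // 16 # embed_dim // num_heads + 80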
+ """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-S/14 for DINOv2 + """ + model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518) + model = _create_vision_transformer( + 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/14 for DINOv2 + """ + model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518) + model = _create_vision_transformer( + 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-L/14 for DINOv2 + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518) + model = _create_vision_transformer( + 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-G/14 for DINOv2 + """ + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, + mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU + ) + model = _create_vision_transformer( + 'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-S/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-L/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def 
vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-G/14 for DINOv2 + """ + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2, + mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, 
global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, + no_embed_class=True, reg_tokens=4, + ) + model = _create_vision_transformer( + 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, + no_embed_class=True, global_pool='avg', reg_tokens=4, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +register_model_deprecations(__name__, { + 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', + 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k', + 'vit_small_patch16_224_in21k': 'vit_small_patch16_224.augreg_in21k', + 'vit_base_patch32_224_in21k': 'vit_base_patch32_224.augreg_in21k', + 'vit_base_patch16_224_in21k': 'vit_base_patch16_224.augreg_in21k', + 'vit_base_patch8_224_in21k': 'vit_base_patch8_224.augreg_in21k', + 'vit_large_patch32_224_in21k': 'vit_large_patch32_224.orig_in21k', + 'vit_large_patch16_224_in21k': 'vit_large_patch16_224.augreg_in21k', + 'vit_huge_patch14_224_in21k': 'vit_huge_patch14_224.orig_in21k', + 'vit_base_patch32_224_sam': 'vit_base_patch32_224.sam', + 'vit_base_patch16_224_sam': 'vit_base_patch16_224.sam', + 'vit_small_patch16_224_dino': 'vit_small_patch16_224.dino', + 'vit_small_patch8_224_dino': 'vit_small_patch8_224.dino', + 'vit_base_patch16_224_dino': 'vit_base_patch16_224.dino', + 'vit_base_patch8_224_dino': 'vit_base_patch8_224.dino', + 'vit_base_patch16_224_miil_in21k': 'vit_base_patch16_224_miil.in21k', + 'vit_base_patch32_224_clip_laion2b': 'vit_base_patch32_clip_224.laion2b', + 'vit_large_patch14_224_clip_laion2b': 'vit_large_patch14_clip_224.laion2b', + 'vit_huge_patch14_224_clip_laion2b': 'vit_huge_patch14_clip_224.laion2b', + 
'vit_giant_patch14_224_clip_laion2b': 'vit_giant_patch14_clip_224.laion2b', +}) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/volo.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/volo.py new file mode 100644 index 0000000000000000000000000000000000000000..260cd20dcc5a6de99f809a797a8be1ec300317d8 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/volo.py @@ -0,0 +1,897 @@ +""" Vision OutLOoker (VOLO) implementation + +Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112 + +Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +# Copyright 2021 Sea Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from ._builder import build_model_with_cfg +from ._registry import register_model, generate_default_cfgs + +__all__ = ['VOLO'] # model_registry will add each entrypoint fn to this + + +class OutlookAttention(nn.Module): + + def __init__( + self, + dim, + num_heads, + kernel_size=3, + padding=1, + stride=1, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): + super().__init__() + head_dim = dim // num_heads + self.num_heads = num_heads + self.kernel_size = kernel_size + self.padding = padding + self.stride = stride + self.scale = head_dim ** -0.5 + + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride) + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True) + + def forward(self, x): + B, H, W, C = x.shape + + v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W + + h, w = math.ceil(H / self.stride), math.ceil(W / self.stride) + v = self.unfold(v).reshape( + B, self.num_heads, C // self.num_heads, + self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H + + attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + attn = self.attn(attn).reshape( + B, h * w, self.num_heads, self.kernel_size * self.kernel_size, + self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk + attn = attn * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w) + x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride) + + x = self.proj(x.permute(0, 2, 
3, 1)) + x = self.proj_drop(x) + + return x + + +class Outlooker(nn.Module): + def __init__( + self, + dim, + kernel_size, + padding, + stride=1, + num_heads=1, + mlp_ratio=3., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + qkv_bias=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OutlookAttention( + dim, + num_heads, + kernel_size=kernel_size, + padding=padding, + stride=stride, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, H, W, C = x.shape + + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Transformer(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ClassAttention(nn.Module): + + def __init__( + self, + dim, + num_heads=8, + head_dim=None, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): + super().__init__() + self.num_heads = num_heads + if head_dim is not None: + self.head_dim = head_dim + else: + head_dim = dim // num_heads + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias) + self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads) + cls_embed = self.proj(cls_embed) + cls_embed = self.proj_drop(cls_embed) + return cls_embed + + +class ClassBlock(nn.Module): + + def __init__( + self, + dim, + num_heads, + head_dim=None, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = ClassAttention( + dim, + num_heads=num_heads, + head_dim=head_dim, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x): + cls_embed = x[:, :1] + cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x))) + cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed))) + return torch.cat([cls_embed, x[:, 1:]], dim=1) + + +def get_block(block_type, **kargs): + if block_type == 'ca': + return ClassBlock(**kargs) + + +def rand_bbox(size, lam, scale=1): + """ + get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling) + return: bounding box + """ + W = size[1] // scale + H = size[2] // scale + cut_rat = np.sqrt(1. - lam) + cut_w = (W * cut_rat).astype(int) + cut_h = (H * cut_rat).astype(int) + + # uniform + cx = np.random.randint(W) + cy = np.random.randint(H) + + bbx1 = np.clip(cx - cut_w // 2, 0, W) + bby1 = np.clip(cy - cut_h // 2, 0, H) + bbx2 = np.clip(cx + cut_w // 2, 0, W) + bby2 = np.clip(cy + cut_h // 2, 0, H) + + return bbx1, bby1, bbx2, bby2 + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding. 
+ Different with ViT use 1 conv layer, we use 4 conv layers to do patch embedding + """ + + def __init__( + self, + img_size=224, + stem_conv=False, + stem_stride=1, + patch_size=8, + in_chans=3, + hidden_dim=64, + embed_dim=384, + ): + super().__init__() + assert patch_size in [4, 8, 16] + if stem_conv: + self.conv = nn.Sequential( + nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + ) + else: + self.conv = None + + self.proj = nn.Conv2d( + hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride) + self.num_patches = (img_size // patch_size) * (img_size // patch_size) + + def forward(self, x): + if self.conv is not None: + x = self.conv(x) + x = self.proj(x) # B, C, H, W + return x + + +class Downsample(nn.Module): + """ Image to Patch Embedding, downsampling between stage1 and stage2 + """ + + def __init__(self, in_embed_dim, out_embed_dim, patch_size=2): + super().__init__() + self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.proj(x) # B, C, H, W + x = x.permute(0, 2, 3, 1) + return x + + +def outlooker_blocks( + block_fn, + index, + dim, + layers, + num_heads=1, + kernel_size=3, + padding=1, + stride=2, + mlp_ratio=3., + qkv_bias=False, + attn_drop=0, + drop_path_rate=0., + **kwargs, +): + """ + generate outlooker layer in stage1 + return: outlooker layers + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_fn( + dim, + kernel_size=kernel_size, + padding=padding, + stride=stride, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + drop_path=block_dpr, + )) + blocks = nn.Sequential(*blocks) + return blocks + + +def transformer_blocks( + block_fn, + index, + dim, + layers, + num_heads, + mlp_ratio=3., + qkv_bias=False, + attn_drop=0, + drop_path_rate=0., + **kwargs, +): + """ + generate transformer layers in stage2 + return: transformer layers + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_fn( + dim, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + drop_path=block_dpr, + )) + blocks = nn.Sequential(*blocks) + return blocks + + +class VOLO(nn.Module): + """ + Vision Outlooker, the main class of our model + """ + + def __init__( + self, + layers, + img_size=224, + in_chans=3, + num_classes=1000, + global_pool='token', + patch_size=8, + stem_hidden_dim=64, + embed_dims=None, + num_heads=None, + downsamples=(True, False, False, False), + outlook_attention=(True, False, False, False), + mlp_ratio=3.0, + qkv_bias=False, + drop_rate=0., + pos_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + post_layers=('ca', 'ca'), + use_aux_head=True, + use_mix_token=False, + pooling_scale=2, + ): + super().__init__() + num_layers = len(layers) + mlp_ratio = to_ntuple(num_layers)(mlp_ratio) + img_size = to_2tuple(img_size) + 
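+        # NOTE: clarifying illustration, not part of the original source. to_ntuple
+        # broadcasts a scalar per-stage setting to one value per stage, e.g. with
+        # num_layers == 4:
+        #   to_ntuple(4)(3.0)                    # -> (3.0, 3.0, 3.0, 3.0)
+        #   to_ntuple(4)((3.0, 3.0, 4.0, 4.0))   # iterable passed through unchanged
+        # so a single mlp_ratio value applies to every stage configured below.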
+ self.num_classes = num_classes + self.global_pool = global_pool + self.mix_token = use_mix_token + self.pooling_scale = pooling_scale + self.num_features = embed_dims[-1] + if use_mix_token: # enable token mixing, see token labeling for details. + self.beta = 1.0 + assert global_pool == 'token', "return all tokens if mix_token is enabled" + self.grad_checkpointing = False + + self.patch_embed = PatchEmbed( + stem_conv=True, + stem_stride=2, + patch_size=patch_size, + in_chans=in_chans, + hidden_dim=stem_hidden_dim, + embed_dim=embed_dims[0], + ) + + # inital positional encoding, we add positional encoding after outlooker blocks + patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale) + self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1])) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + + # set the main block in network + network = [] + for i in range(len(layers)): + if outlook_attention[i]: + # stage 1 + stage = outlooker_blocks( + Outlooker, + i, + embed_dims[i], + layers, + num_heads[i], + mlp_ratio=mlp_ratio[i], + qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + ) + network.append(stage) + else: + # stage 2 + stage = transformer_blocks( + Transformer, + i, + embed_dims[i], + layers, + num_heads[i], + mlp_ratio=mlp_ratio[i], + qkv_bias=qkv_bias, + drop_path_rate=drop_path_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + ) + network.append(stage) + + if downsamples[i]: + # downsampling between two stages + network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2)) + + self.network = nn.ModuleList(network) + + # set post block, for example, class attention layers + self.post_network = None + if post_layers is not None: + self.post_network = nn.ModuleList([ + get_block( + post_layers[i], + dim=embed_dims[-1], + num_heads=num_heads[-1], + mlp_ratio=mlp_ratio[-1], + qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, + drop_path=0., + norm_layer=norm_layer) + for i in range(len(post_layers)) + ]) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + trunc_normal_(self.cls_token, std=.02) + + # set output type + if use_aux_head: + self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + else: + self.aux_head = None + self.norm = norm_layer(self.num_features) + + # Classifier head + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[ + (r'^network\.(\d+)\.(\d+)', None), + (r'^network\.(\d+)', (0,)), + ], + blocks2=[ + (r'^cls_token', (0,)), + (r'^post_network\.(\d+)', None), + (r'^norm', (99999,)) + ], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = 
nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + if self.aux_head is not None: + self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_tokens(self, x): + for idx, block in enumerate(self.network): + if idx == 2: + # add positional encoding after outlooker blocks + x = x + self.pos_embed + x = self.pos_drop(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(block, x) + else: + x = block(x) + + B, H, W, C = x.shape + x = x.reshape(B, -1, C) + return x + + def forward_cls(self, x): + B, N, C = x.shape + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat([cls_tokens, x], dim=1) + for block in self.post_network: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(block, x) + else: + x = block(x) + return x + + def forward_train(self, x): + """ A separate forward fn for training with mix_token (if a train script supports). + Combining multiple modes in as single forward with different return types is torchscript hell. + """ + x = self.patch_embed(x) + x = x.permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C + + # mix token, see token labeling for details. + if self.mix_token and self.training: + lam = np.random.beta(self.beta, self.beta) + patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale + bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale) + temp_x = x.clone() + sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1 + sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2 + temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :] + x = temp_x + else: + bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0 + + # step2: tokens learning in the two stages + x = self.forward_tokens(x) + + # step3: post network, apply class attention or not + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + + if self.global_pool == 'avg': + x_cls = x.mean(dim=1) + elif self.global_pool == 'token': + x_cls = x[:, 0] + else: + x_cls = x + + if self.aux_head is None: + return x_cls + + x_aux = self.aux_head(x[:, 1:]) # generate classes in all feature tokens, see token labeling + if not self.training: + return x_cls + 0.5 * x_aux.max(1)[0] + + if self.mix_token and self.training: # reverse "mix token", see token labeling for details. + x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1]) + temp_x = x_aux.clone() + temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :] + x_aux = temp_x + x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1]) + + # return these: 1. class token, 2. classes from all feature tokens, 3. 
bounding box + return x_cls, x_aux, (bbx1, bby1, bbx2, bby2) + + def forward_features(self, x): + x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C + + # step2: tokens learning in the two stages + x = self.forward_tokens(x) + + # step3: post network, apply class attention or not + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + out = x.mean(dim=1) + elif self.global_pool == 'token': + out = x[:, 0] + else: + out = x + x = self.head_drop(x) + if pre_logits: + return out + out = self.head(out) + if self.aux_head is not None: + # generate classes in all feature tokens, see token labeling + aux = self.aux_head(x[:, 1:]) + out = out + 0.5 * aux.max(1)[0] + return out + + def forward(self, x): + """ simplified forward (without mix token training) """ + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_volo(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + return build_model_with_cfg( + VOLO, + variant, + pretrained, + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv.0', 'classifier': ('head', 'aux_head'), + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'volo_d1_224.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar', + crop_pct=0.96), + 'volo_d1_384.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar', + crop_pct=1.0, input_size=(3, 384, 384)), + 'volo_d2_224.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar', + crop_pct=0.96), + 'volo_d2_384.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar', + crop_pct=1.0, input_size=(3, 384, 384)), + 'volo_d3_224.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar', + crop_pct=0.96), + 'volo_d3_448.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar', + crop_pct=1.0, input_size=(3, 448, 448)), + 'volo_d4_224.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar', + crop_pct=0.96), + 'volo_d4_448.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar', + crop_pct=1.15, input_size=(3, 448, 448)), + 'volo_d5_224.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar', + crop_pct=0.96), + 'volo_d5_448.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar', + crop_pct=1.15, input_size=(3, 448, 448)), + 'volo_d5_512.sail_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar', + crop_pct=1.15, input_size=(3, 512, 512)), +}) + + +@register_model 
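+# NOTE: usage sketch added for clarity, not part of the original file. Each
+# entrypoint below is registered via @register_model and keyed to an entry in
+# default_cfgs above, so models are normally built by name through the timm
+# factory, e.g.:
+#   import timm, torch
+#   model = timm.create_model('volo_d1_224', pretrained=False)
+#   logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 1000)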
+def volo_d1_224(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D1 model, Params: 27M """ + model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) + model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d1_384(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D1 model, Params: 27M """ + model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) + model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d2_224(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D2 model, Params: 59M """ + model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d2_384(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D2 model, Params: 59M """ + model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d3_224(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D3 model, Params: 86M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d3_448(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D3 model, Params: 86M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d4_224(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D4 model, Params: 193M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) + model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d4_448(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D4 model, Params: 193M """ + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) + model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d5_224(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + model_args = dict( + layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), + mlp_ratio=4, stem_hidden_dim=128, **kwargs) + model = _create_volo('volo_d5_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d5_448(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + model_args = dict( + layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), + mlp_ratio=4, stem_hidden_dim=128, **kwargs) + model = _create_volo('volo_d5_448', pretrained=pretrained, **model_args) + return model + + +@register_model +def volo_d5_512(pretrained=False, **kwargs) -> VOLO: + """ VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + model_args = dict( 
+ layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), + mlp_ratio=4, stem_hidden_dim=128, **kwargs) + model = _create_volo('volo_d5_512', pretrained=pretrained, **model_args) + return model diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/xception_aligned.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/xception_aligned.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b284255cb68af9b4e03b6cfba92fd87aa71963 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/xception_aligned.py @@ -0,0 +1,434 @@ +"""Pytorch impl of Aligned Xception 41, 65, 71 + +This is a correct, from scratch impl of Aligned Xception (Deeplab) models compatible with TF weights at +https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md + +Hacked together by / Copyright 2020 Ross Wightman +""" +from functools import partial +from typing import List, Dict, Type, Optional + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer +from timm.layers.helpers import to_3tuple +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + +__all__ = ['XceptionAligned'] + + +class SeparableConv2d(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + padding: PadType = '', + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + ): + super(SeparableConv2d, self).__init__() + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + self.conv_dw = create_conv2d( + in_chs, in_chs, kernel_size, stride=stride, + padding=padding, dilation=dilation, depthwise=True) + self.bn_dw = norm_layer(in_chs) + self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity() + + # pointwise convolution + self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1) + self.bn_pw = norm_layer(out_chs) + self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity() + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn_dw(x) + x = self.act_dw(x) + x = self.conv_pw(x) + x = self.bn_pw(x) + x = self.act_pw(x) + return x + + +class PreSeparableConv2d(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + padding: PadType = '', + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + first_act: bool = True, + ): + super(PreSeparableConv2d, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) + self.kernel_size = kernel_size + self.dilation = dilation + + self.norm = norm_act_layer(in_chs, inplace=True) if first_act else nn.Identity() + # depthwise convolution + self.conv_dw = create_conv2d( + in_chs, in_chs, kernel_size, stride=stride, + padding=padding, dilation=dilation, depthwise=True) + + # pointwise convolution + self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1) + + def forward(self, x): + x = self.norm(x) + x = self.conv_dw(x) + x = self.conv_pw(x) + return x + + +class XceptionModule(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + stride: 
int = 1, + dilation: int = 1, + pad_type: PadType = '', + start_with_relu: bool = True, + no_skip: bool = False, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None + ): + super(XceptionModule, self).__init__() + out_chs = to_3tuple(out_chs) + self.in_channels = in_chs + self.out_channels = out_chs[-1] + self.no_skip = no_skip + if not no_skip and (self.out_channels != self.in_channels or stride != 1): + self.shortcut = ConvNormAct( + in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, apply_act=False) + else: + self.shortcut = None + + separable_act_layer = None if start_with_relu else act_layer + self.stack = nn.Sequential() + for i in range(3): + if start_with_relu: + self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0)) + self.stack.add_module(f'conv{i + 1}', SeparableConv2d( + in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, + act_layer=separable_act_layer, norm_layer=norm_layer)) + in_chs = out_chs[i] + + self.drop_path = drop_path + + def forward(self, x): + skip = x + x = self.stack(x) + if self.shortcut is not None: + skip = self.shortcut(skip) + if not self.no_skip: + if self.drop_path is not None: + x = self.drop_path(x) + x = x + skip + return x + + +class PreXceptionModule(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: int = 1, + pad_type: PadType = '', + no_skip: bool = False, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None + ): + super(PreXceptionModule, self).__init__() + out_chs = to_3tuple(out_chs) + self.in_channels = in_chs + self.out_channels = out_chs[-1] + self.no_skip = no_skip + if not no_skip and (self.out_channels != self.in_channels or stride != 1): + self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride) + else: + self.shortcut = nn.Identity() + + self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True) + self.stack = nn.Sequential() + for i in range(3): + self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d( + in_chs, + out_chs[i], + 3, + stride=stride if i == 2 else 1, + dilation=dilation, + padding=pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + first_act=i > 0, + )) + in_chs = out_chs[i] + + self.drop_path = drop_path + + def forward(self, x): + x = self.norm(x) + skip = x + x = self.stack(x) + if not self.no_skip: + if self.drop_path is not None: + x = self.drop_path(x) + x = x + self.shortcut(skip) + return x + + +class XceptionAligned(nn.Module): + """Modified Aligned Xception + """ + + def __init__( + self, + block_cfg: List[Dict], + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + preact: bool = False, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + drop_rate: float = 0., + drop_path_rate: float = 0., + global_pool: str = 'avg', + ): + super(XceptionAligned, self).__init__() + assert output_stride in (8, 16, 32) + self.num_classes = num_classes + self.drop_rate = drop_rate + self.grad_checkpointing = False + + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) + self.stem = nn.Sequential(*[ + ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), + create_conv2d(32, 64, kernel_size=3, stride=1) if preact else + ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args) + ]) + + curr_dilation = 1 + curr_stride = 2 + 
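+        # NOTE: clarifying comment, not in the original source. The loop below
+        # trades stride for dilation once the running stride reaches
+        # `output_stride`. E.g. with output_stride=8, the first two stride-2
+        # blocks keep their stride (1/4 then 1/8 resolution), while later
+        # stride-2 blocks run at stride 1 with dilation 2, then 4, so features
+        # stay at 1/8 resolution with an enlarged receptive field.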
self.feature_info = [] + self.blocks = nn.Sequential() + module_fn = PreXceptionModule if preact else XceptionModule + net_num_blocks = len(block_cfg) + net_block_idx = 0 + for i, b in enumerate(block_cfg): + block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule + b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None + b['dilation'] = curr_dilation + if b['stride'] > 1: + name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3' + self.feature_info += [dict(num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=name)] + next_stride = curr_stride * b['stride'] + if next_stride > output_stride: + curr_dilation *= b['stride'] + b['stride'] = 1 + else: + curr_stride = next_stride + self.blocks.add_module(str(i), module_fn(**b, **layer_args)) + self.num_features = self.blocks[-1].out_channels + net_block_idx += 1 + + self.feature_info += [dict( + num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] + self.act = act_layer(inplace=True) if preact else nn.Identity() + self.head = ClassifierHead( + in_features=self.num_features, + num_classes=num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^blocks\.(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.act(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + XceptionAligned, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10), + 'crop_pct': 0.903, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'xception65.ra3_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.94, + ), + + 'xception41.tf_in1k': _cfg(hf_hub_id='timm/'), + 'xception65.tf_in1k': _cfg(hf_hub_id='timm/'), + 'xception71.tf_in1k': _cfg(hf_hub_id='timm/'), + + 'xception41p.ra3_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.94, + ), + 'xception65p.ra3_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.94, + ), +}) + + +@register_model +def xception41(pretrained=False, **kwargs) -> XceptionAligned: + """ Modified Aligned Xception-41 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 8), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + 
] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1)) + return _xception('xception41', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def xception65(pretrained=False, **kwargs) -> XceptionAligned: + """ Modified Aligned Xception-65 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1)) + return _xception('xception65', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def xception71(pretrained=False, **kwargs) -> XceptionAligned: + """ Modified Aligned Xception-71 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=1), + dict(in_chs=256, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=1), + dict(in_chs=728, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1)) + return _xception('xception71', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def xception41p(pretrained=False, **kwargs) -> XceptionAligned: + """ Modified Aligned Xception-41 w/ Pre-Act + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 8), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), no_skip=True, stride=1), + ] + model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d) + return _xception('xception41p', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def xception65p(pretrained=False, **kwargs) -> XceptionAligned: + """ Modified Aligned Xception-65 w/ Pre-Act + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True), + ] + model_args = dict( + block_cfg=block_cfg, preact=True, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1)) + return _xception('xception65p', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2eb67bab53c97eaca1204d1d359d1b72eff0415 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/__init__.py @@ -0,0 +1,17 @@ +from .adabelief 
import AdaBelief
+from .adafactor import Adafactor
+from .adahessian import Adahessian
+from .adamp import AdamP
+from .adamw import AdamW
+from .adan import Adan
+from .lamb import Lamb
+from .lars import Lars
+from .lookahead import Lookahead
+from .madgrad import MADGRAD
+from .nadam import Nadam
+from .nvnovograd import NvNovoGrad
+from .radam import RAdam
+from .rmsprop_tf import RMSpropTF
+from .sgdp import SGDP
+from .lion import Lion
+from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adabelief.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adabelief.py
new file mode 100644
index 0000000000000000000000000000000000000000..951d715cc0b605df2f7313c95840b7784c4d0a70
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adabelief.py
@@ -0,0 +1,201 @@
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AdaBelief(Optimizer):
+    r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-16)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+        decoupled_decay (boolean, optional): (default: True) If set as True, then
+            the optimizer uses decoupled weight decay as in AdamW
+        fixed_decay (boolean, optional): (default: False) This is used when decoupled_decay
+            is set as True.
+            When fixed_decay == True, the weight decay is performed as
+            $W_{new} = W_{old} - W_{old} \times decay$.
+            When fixed_decay == False, the weight decay is performed as
+            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
+            weight decay ratio decreases with learning rate (lr).
+        rectify (boolean, optional): (default: True) If set as True, then perform the rectified
+            update similar to RAdam
+        degenerated_to_sgd (boolean, optional): (default: True) If set as True, then perform an SGD update
+            when the variance of the gradient is high
+        reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
+
+    For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer
+    For example train/args for EfficientNet, see these gists
+    - link to train_script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
+    - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
+    """
+
+    def __init__(
+            self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
+            decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
+
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+            for param in params:
+                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+
+        defaults = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
+            degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
+            fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
+        super(AdaBelief, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdaBelief, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                amsgrad = group['amsgrad']
+
+                # State initialization
+                state['step'] = 0
+                # Exponential moving average of gradient values
+                state['exp_avg'] = torch.zeros_like(p)
+
+                # Exponential moving average of squared gradient values
+                state['exp_avg_var'] = torch.zeros_like(p)
+                if amsgrad:
+                    # Maintains max of all exp. moving avg. of sq. grad. values
+                    state['max_exp_avg_var'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
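+
+        A typical closure pattern (illustrative sketch; `model`, `criterion`,
+        `inputs` and `targets` are placeholder names, not part of this module):
+
+            def closure():
+                optimizer.zero_grad()
+                loss = criterion(model(inputs), targets)
+                loss.backward()
+                return loss
+
+            optimizer.step(closure)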
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError( + 'AdaBelief does not support sparse gradients, please consider SparseAdam instead') + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p_fp32) + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p_fp32) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p_fp32) + + # perform weight decay, check if decoupled weight decay + if group['decoupled_decay']: + if not group['fixed_decay']: + p_fp32.mul_(1.0 - group['lr'] * group['weight_decay']) + else: + p_fp32.mul_(1.0 - group['weight_decay']) + else: + if group['weight_decay'] != 0: + grad.add_(p_fp32, alpha=group['weight_decay']) + + # get current state variable + exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Update first and second moment running average + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + grad_residual = grad - exp_avg + exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2) + + if amsgrad: + max_exp_avg_var = state['max_exp_avg_var'] + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var) + + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + # update + if not group['rectify']: + # Default update + step_size = group['lr'] / bias_correction1 + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) + else: + # Rectified update, forked from RAdam + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + num_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + num_sma_max = 2 / (1 - beta2) - 1 + num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = num_sma + + # more conservative since it's an approximated value + if num_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * + (num_sma - 4) / (num_sma_max - 4) * + (num_sma - 2) / num_sma * + num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) + elif group['degenerated_to_sgd']: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + if num_sma >= 5: + denom = exp_avg_var.sqrt().add_(group['eps']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + elif step_size > 0: + p_fp32.add_(exp_avg, alpha=-step_size * group['lr']) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adafactor.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..06057433a9bffa555bdc13b27a1c56cff26acf15 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adafactor.py @@ -0,0 +1,167 @@ +""" Adafactor Optimizer + +Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + +Original header/copyright below. + +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. 
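+
+    Usage sketch (added for illustration; `model` is a placeholder name):
+
+        # internal schedule: relative_step is inferred from lr=None
+        optimizer = Adafactor(model.parameters(), lr=None, warmup_init=True)
+        # external schedule: pass an explicit lr and disable parameter scaling
+        optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False)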
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = not lr + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
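+
+        Note (added for clarity): for parameters with 2 or more dims the second
+        moment is stored factored, as a row EMA R and a column EMA C, and the
+        full matrix is approximated as V ~ (R / mean(R)) outer C. The update
+        then uses the rsqrt of that rank-1 reconstruction (see _approx_sq_grad
+        below), cutting optimizer state from O(n*m) to O(n+m) per matrix.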
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + + factored, use_first_moment = self._get_options(group, grad.shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) + update = exp_avg + + if group['weight_decay'] != 0: + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t) + + p_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adamp.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adamp.py new file mode 100644 index 0000000000000000000000000000000000000000..ee187633ab745dbb0344dcdc3dcb1cf40e6ae5e9 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adamp.py @@ -0,0 +1,105 @@ +""" +AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +import math + + +def _channel_view(x) -> torch.Tensor: + return x.reshape(x.size(0), -1) + + +def _layer_view(x) -> torch.Tensor: + return x.reshape(1, -1) + + +def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): + wd = 1. 
+ expand_size = (-1,) + (1,) * (len(p.shape) - 1) + for view_func in [_channel_view, _layer_view]: + param_view = view_func(p) + grad_view = view_func(grad) + cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_() + + # FIXME this is a problem for PyTorch XLA + if cosine_sim.max() < delta / math.sqrt(param_view.size(1)): + p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size) + wd = wd_ratio + return perturb, wd + + return perturb, wd + + +class AdamP(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + super(AdamP, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + beta1, beta2 = group['betas'] + nesterov = group['nesterov'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + # Adam + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + step_size = group['lr'] / bias_correction1 + + if nesterov: + perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom + else: + perturb = exp_avg / denom + + # Projection + wd_ratio = 1. + if len(p.shape) > 1: + perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if group['weight_decay'] > 0: + p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) + + # Step + p.add_(perturb, alpha=-step_size) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adamw.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..66478bc6ef3c50ab9d40cabb0cfb2bd24277c815 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adamw.py @@ -0,0 +1,122 @@ +""" AdamW Optimizer +Impl copied from PyTorch master + +NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed +someday +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 
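+
+    In decoupled form each step first shrinks the weights multiplicatively,
+    p <- p * (1 - lr * weight_decay), and only then applies the usual Adam
+    update computed from the raw gradients (see `step` below; note added for
+    clarity).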
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adan.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adan.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2a7585e497f0538229655ec893dbb1f5d4301f --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adan.py @@ -0,0 +1,124 @@ +""" Adan Optimizer + +Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + +Implementation adapted from https://github.com/sail-sg/Adan +""" + +import math + +import torch + +from torch.optim import Optimizer + + +class Adan(Optimizer): + """ + Implements a pytorch variant of Adan + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used for computing + running averages of the gradient, its difference, and its square. (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0) + no_prox (bool): if True, apply the decoupled weight decay before the update; + if False, apply it as a proximal step after the update (default: False) + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + no_prox=False, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, no_prox=no_prox) + super(Adan, self).__init__(params, defaults) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure=None): + """ Performs a single optimization step.
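+ + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss.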
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1 ** group['step'] + bias_correction2 = 1.0 - beta2 ** group['step'] + bias_correction3 = 1.0 - beta3 ** group['step'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['pre_grad'] = grad.clone() + + exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq'] + grad_diff = grad - state['pre_grad'] + + exp_avg.lerp_(grad, 1. - beta1) # m_t + exp_avg_diff.lerp_(grad_diff, 1. - beta2) # diff_t (v) + update = grad + beta2 * grad_diff + exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1. - beta3) # n_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction3)).add_(group['eps']) + update = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(denom) + if group['no_prox']: + p.data.mul_(1 - group['lr'] * group['weight_decay']) + p.add_(update, alpha=-group['lr']) + else: + p.add_(update, alpha=-group['lr']) + p.data.div_(1 + group['lr'] * group['weight_decay']) + + state['pre_grad'].copy_(grad) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lamb.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..12c7c49b8a01ef793c97654ac938259ca6508449 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lamb.py @@ -0,0 +1,192 @@ +""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FusedLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) + grad_averaging (bool, optional): whether to apply (1-beta1) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): apply the layer-wise adaptive learning rate even when + weight decay is 0.0 (default: False) + + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): + defaults = dict( + lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, + grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, + trust_clip=trust_clip, always_adapt=always_adapt) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss.
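+ + NOTE: the step makes two passes over the param groups: the first accumulates a global + gradient norm and derives the clipping factor, the second applies the Adam-style update + scaled by the layer-wise trust ratio.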
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group['step'] + bias_correction2 = 1 - beta2 ** group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. 
+ w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lars.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..38ca9e0b5cb90855104ce7b5ff358cb7fa343f12 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lars.py @@ -0,0 +1,135 @@ +""" PyTorch LARS / LARC Optimizer + +An implementation of LARS (SGD) + LARC in PyTorch + +Based on: + * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + +Additional cleanup and modifications to properly support PyTorch XLA. + +Copyright 2021 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer + + +class Lars(Optimizer): + """ LARS for PyTorch + + Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate (default: 1.0). + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) + eps (float): eps for division denominator (default: 1e-8) + trust_clip (bool): enable LARC trust ratio clipping (default: False) + always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False) + """ + + def __init__( + self, + params, + lr=1.0, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + trust_coeff=0.001, + eps=1e-8, + trust_clip=False, + always_adapt=False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + trust_coeff=trust_coeff, + eps=eps, + trust_clip=trust_clip, + always_adapt=always_adapt, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
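+ + NOTE: the adaptive local LR is trust_coeff * ||w|| / (||g|| + weight_decay * ||w|| + eps); + it is folded into the gradient, together with the weight decay, before the standard + SGD momentum update below.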
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + trust_coeff = group['trust_coeff'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + # apply LARS LR adaptation, LARC clipping, weight decay + # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + if weight_decay != 0 or group['always_adapt']: + w_norm = p.norm(2.0) + g_norm = grad.norm(2.0) + trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, trust_ratio, one_tensor), + one_tensor, + ) + if group['trust_clip']: + trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) + grad.add_(p, alpha=weight_decay) + grad.mul_(trust_ratio) + + # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(grad, alpha=1. - dampening) + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + p.add_(grad, alpha=-group['lr']) + + return loss \ No newline at end of file diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lion.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lion.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8086424d190e6e36234200ec159c5d5718c335 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lion.py @@ -0,0 +1,226 @@ +""" Lion Optimizer +Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675 +Original Impl: https://github.com/google/automl/tree/master/lion +""" +# Copyright 2023 Google Research. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import List + +import torch +from torch.optim.optimizer import Optimizer + + +class Lion(Optimizer): + r"""Implements Lion algorithm.""" + + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0.0, + maximize=False, + foreach=None, + ): + """Initialize the hyperparameters. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-4) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float, optional): weight decay coefficient (default: 0) + """ + + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('maximize', False) + group.setdefault('foreach', None) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('Lion does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + + lion( + params_with_grad, + grads, + exp_avgs, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + maximize=group['maximize'], + foreach=group['foreach'], + ) + + return loss + + +def lion( + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + maximize: bool = False, + foreach: bool = None, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, +): + r"""Functional API that performs Lion algorithm computation. 
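+ + Per parameter: update = sign(beta1 * exp_avg + (1 - beta1) * grad), param -= lr * update, + then exp_avg is decayed toward grad with factor (1 - beta2); decoupled weight decay is + applied to param first.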
+ """ + if foreach is None: + # Placeholder for more complex foreach logic to be added when value is not set + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError('torch.jit.script not supported with foreach optimizers') + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_lion + else: + func = _single_tensor_lion + + func( + params, + grads, + exp_avgs, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + ) + + +def _single_tensor_lion( + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + param = torch.view_as_real(param) + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Weight update + update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) + param.add_(torch.sign(update), alpha=-lr) + + # Decay the momentum running average coefficient + exp_avg.lerp_(grad, 1 - beta2) + + +def _multi_tensor_lion( + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, +): + if len(params) == 0: + return + + if maximize: + grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment] + + grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads] + exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs] + params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params] + + # Perform stepweight decay + torch._foreach_mul_(params, 1 - lr * weight_decay) + + # Weight update + updates = torch._foreach_mul(exp_avgs, beta1) + torch._foreach_add_(updates, grads, alpha=1 - beta1) + + updates = [u.sign() for u in updates] + torch._foreach_add_(params, updates, alpha=-lr) + + # Decay the momentum running average coefficient + torch._foreach_mul_(exp_avgs, beta2) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lookahead.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0f1c91d8b69f329205de68e7d8e22126c9f0e0 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/lookahead.py @@ -0,0 +1,66 @@ +""" Lookahead Optimizer Wrapper. 
+Implementation modified from: https://github.com/alphadl/lookahead.pytorch +Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict +from typing import Callable, Dict + +import torch +from torch.optim.optimizer import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + def __init__(self, base_optimizer, alpha=0.5, k=6): + # NOTE super().__init__() not called on purpose + self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict() + self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict() + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) + self._base_optimizer = base_optimizer + self.param_groups = base_optimizer.param_groups + self.defaults = base_optimizer.defaults + self.defaults.update(defaults) + self.state = defaultdict(dict) + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self._base_optimizer.param_groups: + group.setdefault(name, default) + + @torch.no_grad() + def update_slow(self, group): + for fast_p in group["params"]: + if fast_p.grad is None: + continue + param_state = self._base_optimizer.state[fast_p] + if 'lookahead_slow_buff' not in param_state: + param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) + param_state['lookahead_slow_buff'].copy_(fast_p) + slow = param_state['lookahead_slow_buff'] + slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) + fast_p.copy_(slow) + + def sync_lookahead(self): + for group in self._base_optimizer.param_groups: + self.update_slow(group) + + @torch.no_grad() + def step(self, closure=None): + loss = self._base_optimizer.step(closure) + for group in self._base_optimizer.param_groups: + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) + return loss + + def state_dict(self): + return self._base_optimizer.state_dict() + + def load_state_dict(self, state_dict): + self._base_optimizer.load_state_dict(state_dict) + self.param_groups = self._base_optimizer.param_groups diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/madgrad.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/madgrad.py new file mode 100644 index 0000000000000000000000000000000000000000..a76713bf27ed1daf0ce598ac5f25c6238c7fdb57 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/madgrad.py @@ -0,0 +1,184 @@ +""" PyTorch MADGRAD optimizer + +MADGRAD: https://arxiv.org/abs/2101.11075 + +Code from: https://github.com/facebookresearch/madgrad +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch.optim + +if TYPE_CHECKING: + from torch.optim.optimizer import _params_t +else: + _params_t = Any + + +class MADGRAD(torch.optim.Optimizer): + """ + MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic + Optimization. + + .. 
_MADGRAD: https://arxiv.org/abs/2101.11075 + + MADGRAD is a general purpose optimizer that can be used in place of SGD or + Adam and may converge faster and generalize better. Currently GPU-only. + Typically, the same learning rate schedule that is used for SGD or Adam may + be used. The overall learning rate is not comparable to either method and + should be determined by a hyper-parameter sweep. + + MADGRAD requires less weight decay than other methods, often as little as + zero. Momentum values used for SGD or Adam's beta1 should work here also. + + On sparse problems both weight_decay and momentum should be set to 0. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate (default: 1e-2). + momentum (float): + Momentum value in the range [0,1) (default: 0.9). + weight_decay (float): + Weight decay, i.e. an L2 penalty (default: 0). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6). + """ + + def __init__( + self, + params: _params_t, + lr: float = 1e-2, + momentum: float = 0.9, + weight_decay: float = 0, + eps: float = 1e-6, + decoupled_decay: bool = False, + ): + if momentum < 0 or momentum >= 1: + raise ValueError(f"Momentum {momentum} must be in the range [0,1)") + if lr <= 0: + raise ValueError(f"Learning rate {lr} must be positive") + if weight_decay < 0: + raise ValueError(f"Weight decay {weight_decay} must be non-negative") + if eps < 0: + raise ValueError(f"Eps {eps} must be non-negative") + + defaults = dict( + lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay) + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self) -> bool: + return False + + @property + def supports_flat_params(self) -> bool: + return True + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss.
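+ + NOTE: gradients are accumulated into the dual average `s` and their squares into + `grad_sum_sq`, both weighted by lamb = lr * sqrt(step); the update is taken from the + initial point x0, scaled by the cube root of `grad_sum_sq`.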
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + eps = group['eps'] + lr = group['lr'] + eps + weight_decay = group['weight_decay'] + momentum = group['momentum'] + ck = 1 - momentum + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if momentum != 0.0 and grad.is_sparse: + raise RuntimeError("momentum != 0 is not compatible with sparse gradients") + + state = self.state[p] + if len(state) == 0: + state['step'] = 0 + state['grad_sum_sq'] = torch.zeros_like(p) + state['s'] = torch.zeros_like(p) + if momentum != 0: + state['x0'] = torch.clone(p).detach() + + state['step'] += 1 + grad_sum_sq = state['grad_sum_sq'] + s = state['s'] + lamb = lr * math.sqrt(state['step']) + + # Apply weight decay + if weight_decay != 0: + if group['decoupled_decay']: + p.mul_(1.0 - group['lr'] * weight_decay) + else: + if grad.is_sparse: + raise RuntimeError("weight_decay option is not compatible with sparse gradients") + grad.add_(p, alpha=weight_decay) + + if grad.is_sparse: + grad = grad.coalesce() + grad_val = grad._values() + + p_masked = p.sparse_mask(grad) + grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad) + s_masked = s.sparse_mask(grad) + + # Compute x_0 from other known quantities + rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps) + x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1) + + # Dense + sparse op + grad_sq = grad * grad + grad_sum_sq.add_(grad_sq, alpha=lamb) + grad_sum_sq_masked.add_(grad_sq, alpha=lamb) + + rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps) + + s.add_(grad, alpha=lamb) + s_masked._values().add_(grad_val, alpha=lamb) + + # update masked copy of p + p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1) + # Copy updated masked p to dense p using an add operation + p_masked._values().add_(p_kp1_masked_vals, alpha=-1) + p.add_(p_masked, alpha=-1) + else: + if momentum == 0: + # Compute x_0 from other known quantities + rms = grad_sum_sq.pow(1 / 3).add_(eps) + x0 = p.addcdiv(s, rms, value=1) + else: + x0 = state['x0'] + + # Accumulate second moments + grad_sum_sq.addcmul_(grad, grad, value=lamb) + rms = grad_sum_sq.pow(1 / 3).add_(eps) + + # Update s + s.add_(grad, alpha=lamb) + + # Step + if momentum == 0: + p.copy_(x0.addcdiv(s, rms, value=-1)) + else: + z = x0.addcdiv(s, rms, value=-1) + + # p is a moving average of z + p.mul_(1 - ck).add_(z, alpha=ck) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nadamw.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nadamw.py new file mode 100644 index 0000000000000000000000000000000000000000..c823f3d5b229ff135ccfbfb97ff99ded61c8ab4b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nadamw.py @@ -0,0 +1,349 @@ +""" NAdamW Optimizer + +Based on simplified algorithm in https://github.com/mlcommons/algorithmic-efficiency/tree/main/baselines/nadamw + +Added multi-tensor (foreach) path. +""" +import math +from typing import List, Optional + +import torch +from torch import Tensor + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. 
+ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + maximize: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) 
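+ # NOTE step is stored as a tensor (not a python number) so the bias corrections + # can be computed inside a CUDA graph capture when capturable=True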
+ # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=group['maximize'], + capturable=group['capturable'], + ) + + return loss + + +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + foreach: Optional[bool] = None, + capturable: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, +) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + if foreach is None: + foreach = True + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_nadamw + else: + func = _single_tensor_nadamw + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + ) + + +def _single_tensor_nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool +): + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1. - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + if capturable: + step = step_t + + # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor + # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing") + bias_correction1 = 1 - torch.pow(beta1, step) + bias_correction2 = 1 - torch.pow(beta2, step) + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) + + denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) + param.addcdiv_(exp_avg, denom) + else: + step = step_t.item() + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + step_size = lr / bias_correction1 + bias_correction2_sqrt = math.sqrt(bias_correction2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
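+ # The Nesterov correction re-applies the beta1 interpolation to the freshly updated + # first moment; mul() is out-of-place, so the exp_avg state buffer is left unchanged.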
+ exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) + + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + param.addcdiv_(exp_avg, denom, value=-step_size) + + +def _multi_tensor_nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, +): + if len(params) == 0: + return + + if capturable: + assert all( + p.is_cuda and step.is_cuda for p, step in zip(params, state_steps) + ), "If capturable=True, params and state_steps must be CUDA tensors." + + if maximize: + grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment] + + grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads] + exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs] + exp_avg_sqs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avg_sqs] + params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params] + + # update steps + torch._foreach_add_(state_steps, 1) + + # Perform stepweight decay + torch._foreach_mul_(params, 1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) + + torch._foreach_mul_(exp_avg_sqs, beta2) + torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2) + + if capturable: + # TODO: use foreach_pow if/when foreach_pow is added + bias_correction1 = [torch.pow(beta1, step) for step in state_steps] + bias_correction2 = [torch.pow(beta2, step) for step in state_steps] + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + torch._foreach_neg_(bias_correction1) + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + step_size = torch._foreach_div(bias_correction1, lr) + torch._foreach_reciprocal_(step_size) + torch._foreach_neg_(step_size) + + bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + exp_avgs = torch._foreach_mul(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) + + exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_( + exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size) + ) + eps_over_step_size = torch._foreach_div(step_size, eps) + torch._foreach_reciprocal_(eps_over_step_size) + denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size) + + torch._foreach_addcdiv_(params, exp_avgs, denom) + else: + bias_correction1 = [1 - beta1 ** step.item() for step in state_steps] + bias_correction2 = [1 - beta2 ** step.item() for step in state_steps] + + step_size = [(lr / bc) * -1 for bc in bias_correction1] + + bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2] + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
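+ # As in the single-tensor path: rebind exp_avgs to fresh tensors so the state buffers + # keep the plain first moment.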
+ exp_avgs = torch._foreach_mul(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) + + exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + denom = torch._foreach_add(exp_avg_sq_sqrt, eps) + + torch._foreach_addcdiv_(params, exp_avgs, denom, step_size) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nvnovograd.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nvnovograd.py new file mode 100644 index 0000000000000000000000000000000000000000..fda3f4a620fcca5593034dfb9683f2c8f3b78ac1 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nvnovograd.py @@ -0,0 +1,120 @@ +""" Nvidia NovoGrad Optimizer. +Original impl by Nvidia from Jasper example: + - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NvNovoGrad(Optimizer): + """ + Implements Novograd algorithm. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.95, 0.98)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging: gradient averaging + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, + weight_decay=0, grad_averaging=False, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad) + + super(NvNovoGrad, self).__init__(params, defaults) + + def __setstate__(self, state): + super(NvNovoGrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
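+ + NOTE: NovoGrad keeps one scalar second moment per parameter tensor (an EMA of the + squared gradient norm) rather than a per-element one, which substantially reduces + optimizer state memory versus Adam.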
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Sparse gradients are not supported.') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + norm = torch.sum(torch.pow(grad, 2)) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + grad.div_(denom) + if group['weight_decay'] != 0: + grad.add_(p, alpha=group['weight_decay']) + if group['grad_averaging']: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + p.add_(exp_avg, alpha=-group['lr']) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/radam.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8d22e06c42e487c831297008851b4adc254d78 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/radam.py @@ -0,0 +1,89 @@ +"""RAdam Optimizer. 
+Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam +Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)]) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_fp32 = p.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + num_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + num_sma_max = 2 / (1 - beta2) - 1 + num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = num_sma + + # more conservative since it's an approximated value + if num_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * + (num_sma - 4) / (num_sma_max - 4) * + (num_sma - 2) / num_sma * + num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr']) + + # more conservative since it's an approximated value + if num_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) + else: + p_fp32.add_(exp_avg, alpha=-step_size) + + p.copy_(p_fp32) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/rmsprop_tf.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/rmsprop_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..0817887db380261dfee3fcd4bd155b5d923f5248 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/rmsprop_tf.py @@ -0,0 +1,139 @@ +""" RMSProp modified to behave like Tensorflow impl + +Originally cut & paste from PyTorch RMSProp +https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py +Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE + +Modifications Copyright 2021 Ross Wightman +""" + +import torch +from torch.optim import Optimizer + + +class RMSpropTF(Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + + NOTE: This is a direct cut-and-paste of PyTorch 
RMSprop with eps applied before sqrt + and a few other modifications to closer match Tensorflow for matching hyper-params. + + Noteworthy changes include: + 1. Epsilon applied inside square-root + 2. square_avg initialized to ones + 3. LR scaling of update accumulated in momentum buffer + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing (decay) constant (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer + update as per defaults in Tensorflow + + """ + + def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, + decoupled_decay=False, lr_in_momentum=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict( + lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p) + + square_avg = state['square_avg'] + one_minus_alpha = 1. - group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if group['decoupled_decay']: + p.mul_(1. 
- group['lr'] * group['weight_decay']) + else: + grad = grad.add(p, alpha=group['weight_decay']) + + # Tensorflow order of ops for updating squared avg + square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) + # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt + # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original + else: + avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + # Tensorflow accumulates the LR scaling in the momentum buffer + if group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) + p.add_(-buf) + else: + # PyTorch scales the param update by LR + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.add_(buf, alpha=-group['lr']) + else: + p.addcdiv_(grad, avg, value=-group['lr']) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/sgdp.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/sgdp.py new file mode 100644 index 0000000000000000000000000000000000000000..baf05fa55c632371498ec53ff679b11023429df6 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/sgdp.py @@ -0,0 +1,70 @@ +""" +SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer, required +import math + +from .adamp import projection + + +class SGDP(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + defaults = dict( + lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, + nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + super(SGDP, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + + # State initialization + if len(state) == 0: + state['momentum'] = torch.zeros_like(p) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(grad, alpha=1. - dampening) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1. + if len(p.shape) > 1: + d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if weight_decay != 0: + p.mul_(1. 
- group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) + + # Step + p.add_(d_p, alpha=-group['lr']) + + return loss diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/sgdw.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/sgdw.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d2c12f031fd94658f99dac67a2e9a38ea9d9bd --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/sgdw.py @@ -0,0 +1,266 @@ +from functools import update_wrapper, wraps +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +try: + from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach + has_recent_pt = True +except ImportError: + has_recent_pt = False + +from typing import List, Optional + +__all__ = ['SGDW', 'sgdw'] + + +class SGDW(Optimizer): + def __init__( + self, + params, + lr=1e-3, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, + maximize=maximize, foreach=foreach, + differentiable=differentiable) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + group.setdefault('maximize', False) + group.setdefault('foreach', None) + group.setdefault('differentiable', False) + + def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): + has_sparse_grad = False + + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + d_p_list.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + state = self.state[p] + if 'momentum_buffer' not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state['momentum_buffer']) + + return has_sparse_grad + + # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator + # without args, for backwards compatibility with old pytorch + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
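For reference, a minimal usage sketch for the RMSpropTF optimizer shown above. The module path timm.optim.rmsprop_tf, the toy quadratic objective, and the hyper-parameter values are illustrative assumptions, not part of the diff.

```python
import torch
from timm.optim.rmsprop_tf import RMSpropTF  # assumed module path for the class above

# Toy objective: drive a single parameter vector toward zero.
w = torch.nn.Parameter(torch.randn(4))

# Defaults follow the TF conventions noted in the docstring: eps inside the
# sqrt, square_avg initialized to ones, LR folded into the momentum buffer.
opt = RMSpropTF([w], lr=1e-2, alpha=0.9, eps=1e-10, momentum=0.9)

for _ in range(100):
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()
print(loss.item())  # the loss should have decreased markedly
```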
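Likewise, a hedged sketch of driving the SGDP optimizer from the preceding file; the model, data, and hyper-parameter values are placeholders.

```python
import torch
import torch.nn.functional as F
from timm.optim.sgdp import SGDP  # file shown above

model = torch.nn.Linear(8, 2)
# delta and wd_ratio control the projection step that slows weight-norm
# growth for scale-invariant (multi-dimensional) weights (arXiv:2006.08217).
opt = SGDP(model.parameters(), lr=0.1, momentum=0.9, nesterov=True,
           weight_decay=1e-4, delta=0.1, wd_ratio=0.1)

x, y = torch.randn(16, 8), torch.randn(16, 2)
for _ in range(10):
    opt.zero_grad()
    F.mse_loss(model(x), y).backward()
    opt.step()
```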
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + d_p_list = [] + momentum_buffer_list = [] + + has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list) + + sgdw( + params_with_grad, + d_p_list, + momentum_buffer_list, + weight_decay=group['weight_decay'], + momentum=group['momentum'], + lr=group['lr'], + dampening=group['dampening'], + nesterov=group['nesterov'], + maximize=group['maximize'], + has_sparse_grad=has_sparse_grad, + foreach=group['foreach'], + ) + + # update momentum_buffers in state + for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list): + state = self.state[p] + state['momentum_buffer'] = momentum_buffer + + return loss + + +def sgdw( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = None, + foreach: Optional[bool] = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool +): + r"""Functional API that performs SGD algorithm computation. + + See :class:`~torch.optim.SGD` for details. + """ + if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'): + if foreach is None: + # why must we be explicit about an if statement for torch.jit.is_scripting here? + # because JIT can't handle Optionals nor fancy conditionals when scripting + if not torch.jit.is_scripting(): + _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False) + else: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError('torch.jit.script not supported with foreach optimizers') + else: + foreach = False # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_sgdw + else: + func = _single_tensor_sgdw + + func( + params, + d_p_list, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + ) + + +def _single_tensor_sgdw( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool +): + for i, param in enumerate(params): + d_p = d_p_list[i] if not maximize else -d_p_list[i] + + param.mul_(1. 
- lr * weight_decay) + + if momentum != 0: + buf = momentum_buffer_list[i] + + if buf is None: + buf = torch.clone(d_p).detach() + momentum_buffer_list[i] = buf + else: + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + param.add_(d_p, alpha=-lr) + + +def _multi_tensor_sgdw( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool +): + if len(params) == 0: + return + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], with_indices=True) + for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values(): + device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads) + + if maximize: + device_grads = torch._foreach_neg(device_grads) + + torch._foreach_mul_(device_params, 1. - lr * weight_decay) # apply decoupled decay once per device/dtype group + + if momentum != 0: + bufs = [] + + all_states_with_momentum_buffer = True + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + all_states_with_momentum_buffer = False + break + else: + bufs.append(device_momentum_buffer_list[i]) + + if all_states_with_momentum_buffer: + torch._foreach_mul_(bufs, momentum) + torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) + else: + bufs = [] + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \ + torch.clone(device_grads[i]).detach() + else: + buf = device_momentum_buffer_list[i] + buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) + + bufs.append(buf) + + if nesterov: + torch._foreach_add_(device_grads, bufs, alpha=momentum) + else: + device_grads = bufs + + if not device_has_sparse_grad: + torch._foreach_add_(device_params, device_grads, alpha=-lr) + else: + # foreach APIs don't support sparse + for i in range(len(device_params)): + device_params[i].add_(device_grads[i], alpha=-lr) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7191bb0f1c921a5e214b1414cd07269297db95 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/__init__.py @@ -0,0 +1,8 @@ +from .cosine_lr import CosineLRScheduler +from .multistep_lr import MultiStepLRScheduler +from .plateau_lr import PlateauLRScheduler +from .poly_lr import PolyLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler + +from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/cosine_lr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/cosine_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c975fb790126f7dc7ca0b99a9d489800bdb52e --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/cosine_lr.py @@ -0,0 +1,115 @@ +""" Cosine Scheduler + +Cosine LR schedule with warmup, cycle/restarts, noise, k-decay. 
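A small numeric sketch of the decoupled weight decay in SGDW above: the parameter is first scaled by (1 - lr * weight_decay), then the plain SGD/momentum update is applied, so decay never enters the momentum buffer. The values are illustrative.

```python
import torch
from timm.optim.sgdw import SGDW  # file shown above

p = torch.nn.Parameter(torch.ones(3))
opt = SGDW([p], lr=0.1, momentum=0.9, weight_decay=0.01)

opt.zero_grad()
p.sum().backward()  # grad is 1 everywhere
opt.step()

# First step: p = 1 * (1 - 0.1 * 0.01) - 0.1 * 1 = 0.899 per element,
# because the decay is applied to the weights, not mixed into the gradient.
print(p)  # tensor([0.8990, 0.8990, 0.8990], requires_grad=True)
```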
+ +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class CosineLRScheduler(Scheduler): + """ + Cosine decay with restarts. + This is described in the paper https://arxiv.org/abs/1608.03983. + + Inspiration from + https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True, + ) -> None: + super().__init__( + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: + _logger.warning( + "Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = cycle_mul = cycle_decay = 1.") + self.t_initial = t_initial + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.k_decay = k_decay + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay + + if i < self.cycle_limit: + lrs = [ + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/multistep_lr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/multistep_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..10f2fb50446e66cf3e9b70e8163b5fa7e024f84a --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/multistep_lr.py @@ -0,0 +1,63 @@ +""" MultiStep LR Scheduler + +Basic multi step LR schedule with 
warmup, noise. +""" +import torch +import bisect +from timm.scheduler.scheduler import Scheduler +from typing import List + +class MultiStepLRScheduler(Scheduler): + """ Decay the LR by decay_rate at each milestone in decay_t. + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + decay_t: List[int], + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=True, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def get_curr_decay_steps(self, t): + # find where t falls in the milestone array; + # assumes self.decay_t is sorted + return bisect.bisect_right(self.decay_t, t + 1) + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] + return lrs diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/plateau_lr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/plateau_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..9f8271579bbefaf3e9d0322cbe233af638a7433a --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/plateau_lr.py @@ -0,0 +1,110 @@ +""" Plateau Scheduler + +Adapts PyTorch plateau scheduler and allows application of noise, warmup. 
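A usage sketch for the cosine and multi-step schedulers defined above. Unlike the built-in torch.optim schedulers, these take the epoch index explicitly; the optimizer, epoch counts, and hyper-parameter values here are illustrative.

```python
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.multistep_lr import MultiStepLRScheduler

opt = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.5)

sched = CosineLRScheduler(
    opt, t_initial=20, lr_min=1e-4,
    warmup_t=3, warmup_lr_init=1e-5,
    cycle_limit=2, cycle_decay=0.5,  # second cosine cycle peaks at half the LR
)
# The same step() protocol works for the milestone-based variant:
# sched = MultiStepLRScheduler(opt, decay_t=[30, 60], decay_rate=0.1)

for epoch in range(sched.get_cycle_length()):  # 40 epochs for the two cycles
    # ... train one epoch ...
    sched.step(epoch + 1)  # sets the LR for the upcoming epoch
    print(epoch, opt.param_groups[0]['lr'])
```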
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +from .scheduler import Scheduler + + +class PlateauLRScheduler(Scheduler): + """Decay the LR by a factor every time the validation loss plateaus.""" + + def __init__( + self, + optimizer, + decay_rate=0.1, + patience_t=10, + verbose=True, + threshold=1e-4, + cooldown_t=0, + warmup_t=0, + warmup_lr_init=0, + lr_min=0, + mode='max', + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize=True, + ): + super().__init__( + optimizer, + 'lr', + noise_range_t=noise_range_t, + noise_type=noise_type, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) + + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + patience=patience_t, + factor=decay_rate, + verbose=verbose, + threshold=threshold, + cooldown=cooldown_t, + mode=mode, + min_lr=lr_min + ) + + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + self.restore_lr = None + + def state_dict(self): + return { + 'best': self.lr_scheduler.best, + 'last_epoch': self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + self.lr_scheduler.best = state_dict['best'] + if 'last_epoch' in state_dict: + self.lr_scheduler.last_epoch = state_dict['last_epoch'] + + # override the base class step fn completely + def step(self, epoch, metric=None): + if epoch <= self.warmup_t: + lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + super().update_groups(lrs) + else: + if self.restore_lr is not None: + # restore actual LR from before our last noise perturbation before stepping base + for i, param_group in enumerate(self.optimizer.param_groups): + param_group['lr'] = self.restore_lr[i] + self.restore_lr = None + + self.lr_scheduler.step(metric, epoch) # step the base scheduler + + if self._is_apply_noise(epoch): + self._apply_noise(epoch) + + def step_update(self, num_updates: int, metric: float = None): + return None + + def _apply_noise(self, epoch): + noise = self._calculate_noise(epoch) + + # apply the noise on top of previous LR, cache the old value so we can restore for normal + # stepping of base scheduler + restore_lr = [] + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + restore_lr.append(old_lr) + new_lr = old_lr + old_lr * noise + param_group['lr'] = new_lr + self.restore_lr = restore_lr + + def _get_lr(self, t: int) -> float: + assert False, 'should not be called as step is overridden' diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/poly_lr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/poly_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..906f6acf82a7996741427f2388f097ae8a87259b --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/poly_lr.py @@ -0,0 +1,111 @@ +""" Polynomial Scheduler + +Polynomial LR schedule with warmup, noise. 
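PlateauLRScheduler, just shown, wraps torch's ReduceLROnPlateau, so its step() must be fed the monitored metric; the synthetic validation loss below is a stand-in for illustration only.

```python
import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler

opt = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
sched = PlateauLRScheduler(opt, decay_rate=0.5, patience_t=3,
                           warmup_t=2, warmup_lr_init=1e-5, mode='min')

for epoch in range(30):
    val_loss = max(0.1, 1.0 - 0.05 * epoch)  # placeholder for a real validation loss
    sched.step(epoch, metric=val_loss)  # the metric drives the wrapped ReduceLROnPlateau
```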
+ +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +import logging + +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class PolyLRScheduler(Scheduler): + """ Polynomial LR Scheduler w/ warmup, noise, and k-decay + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + power: float = 0.5, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True, + ) -> None: + super().__init__( + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize + ) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: + _logger.warning("Polynomial LR scheduler will have no effect on the learning " + "rate since t_initial = cycle_mul = cycle_decay = 1.") + self.t_initial = t_initial + self.power = power + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.k_decay = k_decay + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay + + if i < self.cycle_limit: + lrs = [ + self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/scheduler.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae2e2aeb6831022899453f30e2620c0483b050d --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/scheduler.py @@ -0,0 +1,127 @@ +import abc +from abc import ABC +from typing import Any, Dict, Optional + +import torch + + +class Scheduler(ABC): + """ Parameter Scheduler Base Class + A scheduler base class that can be used to schedule any optimizer parameter 
groups. + + Unlike the builtin PyTorch schedulers, this is intended to be consistently called + * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value + * At the END of each optimizer update, after incrementing the update count, to calculate next update's value + + The schedulers built on this should try to remain as stateless as possible (for simplicity). + + This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' + and -1 values for special behaviour. All epoch and update counts must be tracked in the training + code and explicitly passed in to the schedulers on the corresponding step or step_update call. + + Based on ideas from: + * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler + * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + t_in_epochs: bool = True, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize: bool = True, + ) -> None: + self.optimizer = optimizer + self.param_group_field = param_group_field + self._initial_param_group_field = f"initial_{param_group_field}" + if initialize: + for i, group in enumerate(self.optimizer.param_groups): + if param_group_field not in group: + raise KeyError(f"{param_group_field} missing from param_groups[{i}]") + group.setdefault(self._initial_param_group_field, group[param_group_field]) + else: + for i, group in enumerate(self.optimizer.param_groups): + if self._initial_param_group_field not in group: + raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") + self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] + self.metric = None # any point to having this for all? 
+ self.t_in_epochs = t_in_epochs + self.noise_range_t = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.update_groups(self.base_values) + + def state_dict(self) -> Dict[str, Any]: + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__dict__.update(state_dict) + + @abc.abstractmethod + def _get_lr(self, t: int) -> float: + pass + + def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]: + proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs) + if not proceed: + return None + return self._get_lr(t) + + def step(self, epoch: int, metric: float = None) -> None: + self.metric = metric + values = self._get_values(epoch, on_epoch=True) + if values is not None: + values = self._add_noise(values, epoch) + self.update_groups(values) + + def step_update(self, num_updates: int, metric: float = None): + self.metric = metric + values = self._get_values(num_updates, on_epoch=False) + if values is not None: + values = self._add_noise(values, num_updates) + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + if 'lr_scale' in param_group: + param_group[self.param_group_field] = value * param_group['lr_scale'] + else: + param_group[self.param_group_field] = value + + def _add_noise(self, lrs, t): + if self._is_apply_noise(t): + noise = self._calculate_noise(t) + lrs = [v + v * noise for v in lrs] + return lrs + + def _is_apply_noise(self, t) -> bool: + """Return True if scheduler in noise range.""" + apply_noise = False + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + return apply_noise + + def _calculate_noise(self, t) -> float: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + return noise + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + return noise diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/scheduler_factory.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/scheduler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..caf68fad4f4a5352227bd1fb53a797912d2e7e61 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/scheduler_factory.py @@ -0,0 +1,206 @@ +""" Scheduler Factory +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List, Optional, Union + +from torch.optim import Optimizer + +from .cosine_lr import CosineLRScheduler +from .multistep_lr import MultiStepLRScheduler +from .plateau_lr import PlateauLRScheduler +from .poly_lr import PolyLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler + + +def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None): + """ cfg/argparse to kwargs helper + Convert scheduler 
args in argparse args or cfg (.dot) like object to keyword args. + """ + eval_metric = getattr(cfg, 'eval_metric', 'top1') + if decreasing_metric is not None: + plateau_mode = 'min' if decreasing_metric else 'max' + else: + plateau_mode = 'min' if 'loss' in eval_metric else 'max' + kwargs = dict( + sched=cfg.sched, + num_epochs=getattr(cfg, 'epochs', 100), + decay_epochs=getattr(cfg, 'decay_epochs', 30), + decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]), + warmup_epochs=getattr(cfg, 'warmup_epochs', 5), + cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0), + patience_epochs=getattr(cfg, 'patience_epochs', 10), + decay_rate=getattr(cfg, 'decay_rate', 0.1), + min_lr=getattr(cfg, 'min_lr', 0.), + warmup_lr=getattr(cfg, 'warmup_lr', 1e-5), + warmup_prefix=getattr(cfg, 'warmup_prefix', False), + noise=getattr(cfg, 'lr_noise', None), + noise_pct=getattr(cfg, 'lr_noise_pct', 0.67), + noise_std=getattr(cfg, 'lr_noise_std', 1.), + noise_seed=getattr(cfg, 'seed', 42), + cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.), + cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1), + cycle_limit=getattr(cfg, 'lr_cycle_limit', 1), + k_decay=getattr(cfg, 'lr_k_decay', 1.0), + plateau_mode=plateau_mode, + step_on_epochs=not getattr(cfg, 'sched_on_updates', False), + ) + return kwargs + + +def create_scheduler( + args, + optimizer: Optimizer, + updates_per_epoch: int = 0, +): + return create_scheduler_v2( + optimizer=optimizer, + **scheduler_kwargs(args), + updates_per_epoch=updates_per_epoch, + ) + + +def create_scheduler_v2( + optimizer: Optimizer, + sched: str = 'cosine', + num_epochs: int = 300, + decay_epochs: int = 90, + decay_milestones: List[int] = (90, 180, 270), + cooldown_epochs: int = 0, + patience_epochs: int = 10, + decay_rate: float = 0.1, + min_lr: float = 0, + warmup_lr: float = 1e-5, + warmup_epochs: int = 0, + warmup_prefix: bool = False, + noise: Union[float, List[float]] = None, + noise_pct: float = 0.67, + noise_std: float = 1., + noise_seed: int = 42, + cycle_mul: float = 1., + cycle_decay: float = 0.1, + cycle_limit: int = 1, + k_decay: float = 1.0, + plateau_mode: str = 'max', + step_on_epochs: bool = True, + updates_per_epoch: int = 0, +): + t_initial = num_epochs + warmup_t = warmup_epochs + decay_t = decay_epochs + cooldown_t = cooldown_epochs + + if not step_on_epochs: + assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches' + t_initial = t_initial * updates_per_epoch + warmup_t = warmup_t * updates_per_epoch + decay_t = decay_t * updates_per_epoch + decay_milestones = [d * updates_per_epoch for d in decay_milestones] + cooldown_t = cooldown_t * updates_per_epoch + + # warmup args + warmup_args = dict( + warmup_lr_init=warmup_lr, + warmup_t=warmup_t, + warmup_prefix=warmup_prefix, + ) + + # setup noise args for supporting schedulers + if noise is not None: + if isinstance(noise, (list, tuple)): + noise_range = [n * t_initial for n in noise] + if len(noise_range) == 1: + noise_range = noise_range[0] + else: + noise_range = noise * t_initial + else: + noise_range = None + noise_args = dict( + noise_range_t=noise_range, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + ) + + # setup cycle args for supporting schedulers + cycle_args = dict( + cycle_mul=cycle_mul, + cycle_decay=cycle_decay, + cycle_limit=cycle_limit, + ) + + lr_scheduler = None + if sched == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=t_initial, + lr_min=min_lr, + t_in_epochs=step_on_epochs, + **cycle_args, + **warmup_args, + 
**noise_args, + k_decay=k_decay, + ) + elif sched == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=t_initial, + lr_min=min_lr, + t_in_epochs=step_on_epochs, + **cycle_args, + **warmup_args, + **noise_args, + ) + elif sched == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=decay_t, + decay_rate=decay_rate, + t_in_epochs=step_on_epochs, + **warmup_args, + **noise_args, + ) + elif sched == 'multistep': + lr_scheduler = MultiStepLRScheduler( + optimizer, + decay_t=decay_milestones, + decay_rate=decay_rate, + t_in_epochs=step_on_epochs, + **warmup_args, + **noise_args, + ) + elif sched == 'plateau': + assert step_on_epochs, 'Plateau LR only supports step per epoch.' + warmup_args.pop('warmup_prefix', False) + lr_scheduler = PlateauLRScheduler( + optimizer, + decay_rate=decay_rate, + patience_t=patience_epochs, + cooldown_t=0, + **warmup_args, + lr_min=min_lr, + mode=plateau_mode, + **noise_args, + ) + elif sched == 'poly': + lr_scheduler = PolyLRScheduler( + optimizer, + power=decay_rate, # overloading 'decay_rate' as polynomial power + t_initial=t_initial, + lr_min=min_lr, + t_in_epochs=step_on_epochs, + k_decay=k_decay, + **cycle_args, + **warmup_args, + **noise_args, + ) + + if hasattr(lr_scheduler, 'get_cycle_length'): + # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown + t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t + if step_on_epochs: + num_epochs = t_with_cycles_and_cooldown + else: + num_epochs = t_with_cycles_and_cooldown // updates_per_epoch + + return lr_scheduler, num_epochs diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/step_lr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..70a45a70d4c547be2527f77452c3675a1b05b818 --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/step_lr.py @@ -0,0 +1,61 @@ +""" Step Scheduler + +Basic step LR schedule with warmup, noise. 
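The factory above can be exercised end to end; here is a sketch building an update-stepped polynomial schedule with LR noise (all values illustrative). Note that decay_rate is reused as the polynomial power when sched='poly'.

```python
import torch
from timm.scheduler import create_scheduler_v2  # re-exported in __init__.py above

opt = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
updates_per_epoch = 1000  # would normally come from len(train_loader)

sched, num_epochs = create_scheduler_v2(
    opt,
    sched='poly',
    num_epochs=100,
    decay_rate=2.0,          # polynomial power for 'poly'
    warmup_epochs=5,
    min_lr=1e-5,
    noise=[0.4, 0.9],        # noise active between 40% and 90% of the schedule
    step_on_epochs=False,    # schedule per optimizer update, not per epoch
    updates_per_epoch=updates_per_epoch,
)

num_updates = 0
for epoch in range(num_epochs):
    for _ in range(updates_per_epoch):
        # ... forward / backward / opt.step() ...
        num_updates += 1
        sched.step_update(num_updates)  # drives the LR when step_on_epochs=False
```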
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import torch + +from .scheduler import Scheduler + + +class StepLRScheduler(Scheduler): + """ Decay the LR by decay_rate every decay_t steps. + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + decay_t: float, + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=True, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] + return lrs diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/tanh_lr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/tanh_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..48acc61b033491dddfcb4c739549fa7f9b64661e --- /dev/null +++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/scheduler/tanh_lr.py @@ -0,0 +1,112 @@ +""" TanH Scheduler + +TanH schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class TanhLRScheduler(Scheduler): + """ + Hyperbolic-Tangent decay with restarts. 
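A sketch for StepLRScheduler as defined above, halving the LR every decay_t epochs after warmup; the epoch counts and values are illustrative.

```python
import torch
from timm.scheduler.step_lr import StepLRScheduler

opt = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
# Halve the LR every 10 epochs after a 2-epoch warmup (warmup_prefix=True
# by default, so the decay clock starts once warmup is done).
sched = StepLRScheduler(opt, decay_t=10, decay_rate=0.5,
                        warmup_t=2, warmup_lr_init=1e-4)

for epoch in range(30):
    # ... train one epoch ...
    sched.step(epoch + 1)  # explicit epoch index, per the base class contract
```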
+ This is described in the paper https://arxiv.org/abs/1806.01593 + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lb: float = -7., + ub: float = 3., + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) + + assert t_initial > 0 + assert lr_min >= 0 + assert lb < ub + assert cycle_limit >= 0 + assert warmup_t >= 0 + assert warmup_lr_init >= 0 + self.lb = lb + self.ub = ub + self.t_initial = t_initial + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + if self.warmup_t: + t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + if i < self.cycle_limit: + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + + tr = t_curr / t_i + lrs = [ + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + return lrs + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
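Finally, a sketch checking the tanh schedule above against its closed form: lb and ub bound the tanh argument as the training progress tr runs from 0 to 1. The scheduler values are illustrative.

```python
import math
import torch
from timm.scheduler.tanh_lr import TanhLRScheduler

opt = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
sched = TanhLRScheduler(opt, t_initial=50, lb=-7.0, ub=3.0, lr_min=1e-5)

# lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 - tanh(lb * (1 - tr) + ub * tr)),
# with tr = t / t_initial, matching _get_lr for a single cycle without warmup.
for t in (0, 25, 49):
    tr = t / 50
    lr = 1e-5 + 0.5 * (0.1 - 1e-5) * (1 - math.tanh(-7.0 * (1 - tr) + 3.0 * tr))
    assert abs(lr - sched._get_lr(t)[0]) < 1e-12
```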