WhiskeyRacoon committed on
Commit
7aaf3f7
·
verified ·
1 Parent(s): 79ae48f

Upload maxxvit.py

  1. maxxvit.py +1913 -0
maxxvit.py ADDED
@@ -0,0 +1,1913 @@
1
+ """ MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch
2
+
3
+ This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.
4
+
5
+ 99% of the implementation was done from the papers; however, some last-minute adjustments were made
6
+ based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
7
+
8
+ There are multiple sets of models defined for both architectures. Typically, names with a
9
+ `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit.
10
+ These configs work well and appear to be a bit faster / lower resource than the paper configs.
11
+
12
+ The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc.) are intended to
13
+ match the paper, BUT without any official pretrained weights it's difficult to confirm a 100% match.
14
+
15
+ # FIXME / WARNING
16
+ This impl remains a WIP, some configs and models may vanish or change...
17
+
18
+ Papers:
19
+
20
+ MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
21
+ @article{tu2022maxvit,
22
+ title={MaxViT: Multi-Axis Vision Transformer},
23
+ author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
24
+ journal={ECCV},
25
+ year={2022},
26
+ }
27
+
28
+ CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
29
+ @article{DBLP:journals/corr/abs-2106-04803,
30
+ author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
31
+ title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
32
+ journal = {CoRR},
33
+ volume = {abs/2106.04803},
34
+ year = {2021}
35
+ }
36
+
37
+ Hacked together by / Copyright 2022, Ross Wightman
38
+ """
39
+
40
+ import math
41
+ from collections import OrderedDict
42
+ from dataclasses import dataclass, replace, field
43
+ from functools import partial
44
+ from typing import Callable, Optional, Union, Tuple, List
45
+
46
+ import torch
47
+ from torch import nn
48
+ from torch.utils.checkpoint import checkpoint
49
+
50
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
51
+ from .helpers import build_model_with_cfg, checkpoint_seq, named_apply
52
+ from .fx_features import register_notrace_function
53
+ from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm
54
+ from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
55
+ from .layers import to_2tuple, extend_tuple, make_divisible, _assert
56
+ from .registry import register_model
57
+ from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location
58
+
59
+ __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
60
+
61
+
62
+ def _cfg(url='', **kwargs):
63
+ return {
64
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
65
+ 'crop_pct': 0.95, 'interpolation': 'bicubic',
66
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
67
+ 'first_conv': 'stem.conv1', 'classifier': 'head.fc',
68
+ 'fixed_input_size': True,
69
+ **kwargs
70
+ }
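The following is a minimal sketch of how `_cfg` merges per-model overrides into the shared defaults. The function body is copied from the listing above; the override values passed to it are placeholders chosen only for illustration.

```python
def _cfg(url='', **kwargs):
    # Shared defaults; any keyword argument overrides the matching key
    # because **kwargs is unpacked last in the dict literal.
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        'first_conv': 'stem.conv1', 'classifier': 'head.fc',
        'fixed_input_size': True,
        **kwargs
    }


base = _cfg()
custom = _cfg(input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.9)

print(base['input_size'], base['pool_size'], base['crop_pct'])
print(custom['input_size'], custom['pool_size'], custom['crop_pct'])
```

Keys that are not overridden (for example `mean` and `std`) keep the default values in both dictionaries.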
71
+
72
+
73
+ default_cfgs = {
74
+ # Fiddling with configs / defaults / still pretraining
75
+ 'coatnet_pico_rw_224': _cfg(url=''),
76
+ 'coatnet_nano_rw_224': _cfg(
77
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
78
+ crop_pct=0.9),
79
+ 'coatnet_0_rw_224': _cfg(
80
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
81
+ 'coatnet_1_rw_224': _cfg(
82
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
83
+ ),
84
+ 'coatnet_2_rw_224': _cfg(url=''),
85
+ 'coatnet_3_rw_224': _cfg(url=''),
86
+
87
+ # Highly experimental configs
88
+ 'coatnet_bn_0_rw_224': _cfg(
89
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
90
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
91
+ crop_pct=0.95),
92
+ 'coatnet_rmlp_nano_rw_224': _cfg(
93
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
94
+ crop_pct=0.9),
95
+ 'coatnet_rmlp_0_rw_224': _cfg(url=''),
96
+ 'coatnet_rmlp_1_rw_224': _cfg(
97
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
98
+ 'coatnet_rmlp_2_rw_224': _cfg(
99
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
100
+ 'coatnet_rmlp_3_rw_224': _cfg(url=''),
101
+ 'coatnet_nano_cc_224': _cfg(url=''),
102
+ 'coatnext_nano_rw_224': _cfg(
103
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
104
+ crop_pct=0.9),
105
+
106
+ # Trying to be like the CoAtNet paper configs
107
+ 'coatnet_0_224': _cfg(url=''),
108
+ 'coatnet_1_224': _cfg(url=''),
109
+ 'coatnet_2_224': _cfg(url=''),
110
+ 'coatnet_3_224': _cfg(url=''),
111
+ 'coatnet_4_224': _cfg(url=''),
112
+ 'coatnet_5_224': _cfg(url=''),
113
+
114
+ # Experimental configs
115
+ 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
116
+ 'maxvit_nano_rw_256': _cfg(
117
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
118
+ input_size=(3, 256, 256), pool_size=(8, 8)),
119
+ 'maxvit_tiny_rw_224': _cfg(
120
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
121
+ 'maxvit_tiny_rw_256': _cfg(
122
+ url='',
123
+ input_size=(3, 256, 256), pool_size=(8, 8)),
124
+ 'maxvit_rmlp_pico_rw_256': _cfg(
125
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
126
+ input_size=(3, 256, 256), pool_size=(8, 8)),
127
+ 'maxvit_rmlp_nano_rw_256': _cfg(
128
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
129
+ input_size=(3, 256, 256), pool_size=(8, 8)),
130
+ 'maxvit_rmlp_tiny_rw_256': _cfg(
131
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
132
+ input_size=(3, 256, 256), pool_size=(8, 8)),
133
+ 'maxvit_rmlp_small_rw_224': _cfg(
134
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
135
+ crop_pct=0.9,
136
+ ),
137
+ 'maxvit_rmlp_small_rw_256': _cfg(
138
+ url='',
139
+ input_size=(3, 256, 256), pool_size=(8, 8)),
140
+
141
+ 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
142
+
143
+ 'maxxvit_rmlp_nano_rw_256': _cfg(
144
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
145
+ input_size=(3, 256, 256), pool_size=(8, 8)),
146
+ 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
147
+ 'maxxvit_rmlp_small_rw_256': _cfg(
148
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
149
+ input_size=(3, 256, 256), pool_size=(8, 8)),
150
+
151
+ # Trying to be like the MaxViT paper configs
152
+ 'maxvit_tiny_224': _cfg(url=''),
153
+ 'maxvit_small_224': _cfg(url=''),
154
+ 'maxvit_base_224': _cfg(url=''),
155
+ 'maxvit_large_224': _cfg(url=''),
156
+ 'maxvit_xlarge_224': _cfg(url=''),
157
+ }
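The `pool_size` values above track `input_size`: 224 maps to (7, 7) and 256 maps to (8, 8). A short sketch of that relationship follows, assuming an overall stride of 32 (a stride-2 stem followed by four stride-2 stages). The helper name `pool_size_for` is hypothetical and not part of this file.

```python
def pool_size_for(input_size, total_stride=32):
    """Spatial size of the final feature map for a (C, H, W) input.

    total_stride=32 is an assumption about the network, not a value read
    from the configs above.
    """
    _, height, width = input_size
    return height // total_stride, width // total_stride


print(pool_size_for((3, 224, 224)))
print(pool_size_for((3, 256, 256)))
```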
158
+
159
+
160
+ @dataclass
161
+ class MaxxVitTransformerCfg:
162
+ dim_head: int = 32
163
+ expand_ratio: float = 4.0
164
+ expand_first: bool = True
165
+ shortcut_bias: bool = True
166
+ attn_bias: bool = True
167
+ attn_drop: float = 0.
168
+ proj_drop: float = 0.
169
+ pool_type: str = 'avg2'
170
+ rel_pos_type: str = 'bias'
171
+ rel_pos_dim: int = 512 # for relative position types w/ MLP
172
+ partition_ratio: int = 32
173
+ window_size: Optional[Tuple[int, int]] = None
174
+ grid_size: Optional[Tuple[int, int]] = None
175
+ init_values: Optional[float] = None
176
+ act_layer: str = 'gelu'
177
+ norm_layer: str = 'layernorm2d'
178
+ norm_layer_cl: str = 'layernorm'
179
+ norm_eps: float = 1e-6
180
+
181
+ def __post_init__(self):
182
+ if self.grid_size is not None:
183
+ self.grid_size = to_2tuple(self.grid_size)
184
+ if self.window_size is not None:
185
+ self.window_size = to_2tuple(self.window_size)
186
+ if self.grid_size is None:
187
+ self.grid_size = self.window_size
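As a rough illustration of the `__post_init__` logic above, the sketch below reduces the config to the two partition fields. `TransformerCfgSketch` is a hypothetical stand-in for `MaxxVitTransformerCfg`, and the local `to_2tuple` only approximates the timm helper of the same name.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


def to_2tuple(x):
    # Approximation of timm's helper: pass pairs through, repeat scalars.
    if isinstance(x, (tuple, list)):
        return tuple(x)
    return (x, x)


@dataclass
class TransformerCfgSketch:
    window_size: Optional[Union[int, Tuple[int, int]]] = None
    grid_size: Optional[Union[int, Tuple[int, int]]] = None

    def __post_init__(self):
        if self.grid_size is not None:
            self.grid_size = to_2tuple(self.grid_size)
        if self.window_size is not None:
            self.window_size = to_2tuple(self.window_size)
            if self.grid_size is None:
                # Grid partition falls back to the window partition size.
                self.grid_size = self.window_size


print(TransformerCfgSketch(window_size=7))
print(TransformerCfgSketch(window_size=8, grid_size=4))
print(TransformerCfgSketch())
```

When neither field is given, both stay `None`.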
188
+
189
+
190
+ @dataclass
191
+ class MaxxVitConvCfg:
192
+ block_type: str = 'mbconv'
193
+ expand_ratio: float = 4.0
194
+ expand_output: bool = True # calculate expansion channels from output (vs input chs)
195
+ kernel_size: int = 3
196
+ group_size: int = 1 # 1 == depthwise
197
+ pre_norm_act: bool = False # activation after pre-norm
198
+ output_bias: bool = True # bias for shortcut + final 1x1 projection conv
199
+ stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
200
+ pool_type: str = 'avg2'
201
+ downsample_pool_type: str = 'avg2'
202
+ attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2
203
+ attn_layer: str = 'se'
204
+ attn_act_layer: str = 'silu'
205
+ attn_ratio: float = 0.25
206
+ init_values: Optional[float] = 1e-6 # for ConvNeXt block, ignored by MBConv
207
+ act_layer: str = 'gelu'
208
+ norm_layer: str = ''
209
+ norm_layer_cl: str = ''
210
+ norm_eps: Optional[float] = None
211
+
212
+ def __post_init__(self):
213
+ # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
214
+ assert self.block_type in ('mbconv', 'convnext')
215
+ use_mbconv = self.block_type == 'mbconv'
216
+ if not self.norm_layer:
217
+ self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
218
+ if not self.norm_layer_cl and not use_mbconv:
219
+ self.norm_layer_cl = 'layernorm'
220
+ if self.norm_eps is None:
221
+ self.norm_eps = 1e-5 if use_mbconv else 1e-6
222
+ self.downsample_pool_type = self.downsample_pool_type or self.pool_type
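The block-type-dependent defaults can be seen in isolation with a reduced config. `ConvCfgSketch` below is a hypothetical stand-in that keeps only the fields touched by the norm-related part of `__post_init__`; the branch logic is copied from the listing above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConvCfgSketch:
    block_type: str = 'mbconv'
    norm_layer: str = ''
    norm_layer_cl: str = ''
    norm_eps: Optional[float] = None

    def __post_init__(self):
        assert self.block_type in ('mbconv', 'convnext')
        use_mbconv = self.block_type == 'mbconv'
        if not self.norm_layer:
            # MBConv defaults to batch norm, ConvNeXt to 2D layer norm.
            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
        if not self.norm_layer_cl and not use_mbconv:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-5 if use_mbconv else 1e-6


mbconv = ConvCfgSketch()
convnext = ConvCfgSketch(block_type='convnext')
explicit = ConvCfgSketch(norm_layer='layernorm2d', norm_eps=1e-3)

print(mbconv)
print(convnext)
print(explicit)
```

Explicitly passed values are left alone; only empty strings and `None` are filled in.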
223
+
224
+
225
+ @dataclass
226
+ class MaxxVitCfg:
227
+ embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
228
+ depths: Tuple[int, ...] = (2, 3, 5, 2)
229
+ block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
230
+ stem_width: Union[int, Tuple[int, int]] = 64
231
+ stem_bias: bool = True
232
+ conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
233
+ transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
234
+ weight_init: str = 'vit_eff'
235
+
236
+
237
+ def _rw_coat_cfg(
238
+ stride_mode='pool',
239
+ pool_type='avg2',
240
+ conv_output_bias=False,
241
+ conv_attn_early=False,
242
+ conv_attn_act_layer='relu',
243
+ conv_norm_layer='',
244
+ transformer_shortcut_bias=True,
245
+ transformer_norm_layer='layernorm2d',
246
+ transformer_norm_layer_cl='layernorm',
247
+ init_values=None,
248
+ rel_pos_type='bias',
249
+ rel_pos_dim=512,
250
+ ):
251
+ # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit
252
+ # Common differences for initial timm models:
253
+ # - pre-norm layer in MBConv included an activation after norm
254
+ # - mbconv expansion calculated from input instead of output chs
255
+ # - mbconv shortcut and final 1x1 conv did not have a bias
256
+ # - SE act layer was relu, not silu
257
+ # - mbconv uses silu in timm, not gelu
258
+ # - expansion in attention block done via output proj, not input proj
259
+ # Variable differences (evolved over training initial models):
260
+ # - avg pool with kernel_size=2 favoured for downsampling (instead of maxpool for coat)
261
+ # - SE attention was between conv2 and norm/act
262
+ # - default to avg pool for mbconv downsample instead of 1x1 or dw conv
263
+ # - transformer block shortcut has no bias
264
+ return dict(
265
+ conv_cfg=MaxxVitConvCfg(
266
+ stride_mode=stride_mode,
267
+ pool_type=pool_type,
268
+ pre_norm_act=True,
269
+ expand_output=False,
270
+ output_bias=conv_output_bias,
271
+ attn_early=conv_attn_early,
272
+ attn_act_layer=conv_attn_act_layer,
273
+ act_layer='silu',
274
+ norm_layer=conv_norm_layer,
275
+ ),
276
+ transformer_cfg=MaxxVitTransformerCfg(
277
+ expand_first=False,
278
+ shortcut_bias=transformer_shortcut_bias,
279
+ pool_type=pool_type,
280
+ init_values=init_values,
281
+ norm_layer=transformer_norm_layer,
282
+ norm_layer_cl=transformer_norm_layer_cl,
283
+ rel_pos_type=rel_pos_type,
284
+ rel_pos_dim=rel_pos_dim,
285
+ ),
286
+ )
287
+
288
+
289
+ def _rw_max_cfg(
290
+ stride_mode='dw',
291
+ pool_type='avg2',
292
+ conv_output_bias=False,
293
+ conv_attn_ratio=1 / 16,
294
+ conv_norm_layer='',
295
+ transformer_norm_layer='layernorm2d',
296
+ transformer_norm_layer_cl='layernorm',
297
+ window_size=None,
298
+ dim_head=32,
299
+ init_values=None,
300
+ rel_pos_type='bias',
301
+ rel_pos_dim=512,
302
+ ):
303
+ # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit
304
+ # Differences of initial timm models:
305
+ # - mbconv expansion calculated from input instead of output chs
306
+ # - mbconv shortcut and final 1x1 conv did not have a bias
307
+ # - mbconv uses silu in timm, not gelu
308
+ # - expansion in attention block done via output proj, not input proj
309
+ return dict(
310
+ conv_cfg=MaxxVitConvCfg(
311
+ stride_mode=stride_mode,
312
+ pool_type=pool_type,
313
+ expand_output=False,
314
+ output_bias=conv_output_bias,
315
+ attn_ratio=conv_attn_ratio,
316
+ act_layer='silu',
317
+ norm_layer=conv_norm_layer,
318
+ ),
319
+ transformer_cfg=MaxxVitTransformerCfg(
320
+ expand_first=False,
321
+ pool_type=pool_type,
322
+ dim_head=dim_head,
323
+ window_size=window_size,
324
+ init_values=init_values,
325
+ norm_layer=transformer_norm_layer,
326
+ norm_layer_cl=transformer_norm_layer_cl,
327
+ rel_pos_type=rel_pos_type,
328
+ rel_pos_dim=rel_pos_dim,
329
+ ),
330
+ )
331
+
332
+
333
+ def _next_cfg(
334
+ stride_mode='dw',
335
+ pool_type='avg2',
336
+ conv_norm_layer='layernorm2d',
337
+ conv_norm_layer_cl='layernorm',
338
+ transformer_norm_layer='layernorm2d',
339
+ transformer_norm_layer_cl='layernorm',
340
+ window_size=None,
341
+ init_values=1e-6,
342
+ rel_pos_type='mlp', # MLP by default for maxxvit
343
+ rel_pos_dim=512,
344
+ ):
345
+ # For experimental models with convnext instead of mbconv
346
+ init_values = to_2tuple(init_values)
347
+ return dict(
348
+ conv_cfg=MaxxVitConvCfg(
349
+ block_type='convnext',
350
+ stride_mode=stride_mode,
351
+ pool_type=pool_type,
352
+ expand_output=False,
353
+ init_values=init_values[0],
354
+ norm_layer=conv_norm_layer,
355
+ norm_layer_cl=conv_norm_layer_cl,
356
+ ),
357
+ transformer_cfg=MaxxVitTransformerCfg(
358
+ expand_first=False,
359
+ pool_type=pool_type,
360
+ window_size=window_size,
361
+ init_values=init_values[1],
362
+ norm_layer=transformer_norm_layer,
363
+ norm_layer_cl=transformer_norm_layer_cl,
364
+ rel_pos_type=rel_pos_type,
365
+ rel_pos_dim=rel_pos_dim,
366
+ ),
367
+ )
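`_next_cfg` accepts either a single `init_values` scalar or a pair, and passes element 0 to the conv config and element 1 to the transformer config. A minimal sketch of that split follows. The helper name `split_init_values` is hypothetical, and it mimics `to_2tuple` rather than calling it.

```python
def split_init_values(init_values):
    # A pair is read as (conv, transformer); a scalar applies to both.
    if isinstance(init_values, (tuple, list)):
        conv_init, transformer_init = init_values
    else:
        conv_init = transformer_init = init_values
    return conv_init, transformer_init


print(split_init_values(1e-6))          # same value for both block types
print(split_init_values((1e-5, None)))  # separate conv and transformer values
```

The second call mirrors the `coatnext_nano_rw_224` config further down in the file, which passes `init_values=(1e-5, None)`.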
368
+
369
+
370
+ model_cfgs = dict(
371
+ # Fiddling with configs / defaults / still pretraining
372
+ coatnet_pico_rw_224=MaxxVitCfg(
373
+ embed_dim=(64, 128, 256, 512),
374
+ depths=(2, 3, 5, 2),
375
+ stem_width=(32, 64),
376
+ **_rw_max_cfg( # using newer max defaults here
377
+ conv_output_bias=True,
378
+ conv_attn_ratio=0.25,
379
+ ),
380
+ ),
381
+ coatnet_nano_rw_224=MaxxVitCfg(
382
+ embed_dim=(64, 128, 256, 512),
383
+ depths=(3, 4, 6, 3),
384
+ stem_width=(32, 64),
385
+ **_rw_max_cfg( # using newer max defaults here
386
+ stride_mode='pool',
387
+ conv_output_bias=True,
388
+ conv_attn_ratio=0.25,
389
+ ),
390
+ ),
391
+ coatnet_0_rw_224=MaxxVitCfg(
392
+ embed_dim=(96, 192, 384, 768),
393
+ depths=(2, 3, 7, 2), # deeper than paper '0' model
394
+ stem_width=(32, 64),
395
+ **_rw_coat_cfg(
396
+ conv_attn_early=True,
397
+ transformer_shortcut_bias=False,
398
+ ),
399
+ ),
400
+ coatnet_1_rw_224=MaxxVitCfg(
401
+ embed_dim=(96, 192, 384, 768),
402
+ depths=(2, 6, 14, 2),
403
+ stem_width=(32, 64),
404
+ **_rw_coat_cfg(
405
+ stride_mode='dw',
406
+ conv_attn_early=True,
407
+ transformer_shortcut_bias=False,
408
+ )
409
+ ),
410
+ coatnet_2_rw_224=MaxxVitCfg(
411
+ embed_dim=(128, 256, 512, 1024),
412
+ depths=(2, 6, 14, 2),
413
+ stem_width=(64, 128),
414
+ **_rw_coat_cfg(
415
+ stride_mode='dw',
416
+ conv_attn_act_layer='silu',
417
+ init_values=1e-6,
418
+ ),
419
+ ),
420
+ coatnet_3_rw_224=MaxxVitCfg(
421
+ embed_dim=(192, 384, 768, 1536),
422
+ depths=(2, 6, 14, 2),
423
+ stem_width=(96, 192),
424
+ **_rw_coat_cfg(
425
+ stride_mode='dw',
426
+ conv_attn_act_layer='silu',
427
+ init_values=1e-6,
428
+ ),
429
+ ),
430
+
431
+ # Highly experimental configs
432
+ coatnet_bn_0_rw_224=MaxxVitCfg(
433
+ embed_dim=(96, 192, 384, 768),
434
+ depths=(2, 3, 7, 2), # deeper than paper '0' model
435
+ stem_width=(32, 64),
436
+ **_rw_coat_cfg(
437
+ stride_mode='dw',
438
+ conv_attn_early=True,
439
+ transformer_shortcut_bias=False,
440
+ transformer_norm_layer='batchnorm2d',
441
+ )
442
+ ),
443
+ coatnet_rmlp_nano_rw_224=MaxxVitCfg(
444
+ embed_dim=(64, 128, 256, 512),
445
+ depths=(3, 4, 6, 3),
446
+ stem_width=(32, 64),
447
+ **_rw_max_cfg(
448
+ conv_output_bias=True,
449
+ conv_attn_ratio=0.25,
450
+ rel_pos_type='mlp',
451
+ rel_pos_dim=384,
452
+ ),
453
+ ),
454
+ coatnet_rmlp_0_rw_224=MaxxVitCfg(
455
+ embed_dim=(96, 192, 384, 768),
456
+ depths=(2, 3, 7, 2), # deeper than paper '0' model
457
+ stem_width=(32, 64),
458
+ **_rw_coat_cfg(
459
+ stride_mode='dw',
460
+ rel_pos_type='mlp',
461
+ ),
462
+ ),
463
+ coatnet_rmlp_1_rw_224=MaxxVitCfg(
464
+ embed_dim=(96, 192, 384, 768),
465
+ depths=(2, 6, 14, 2),
466
+ stem_width=(32, 64),
467
+ **_rw_coat_cfg(
468
+ pool_type='max',
469
+ conv_attn_early=True,
470
+ transformer_shortcut_bias=False,
471
+ rel_pos_type='mlp',
472
+ rel_pos_dim=384, # was supposed to be 512, woops
473
+ ),
474
+ ),
475
+ coatnet_rmlp_2_rw_224=MaxxVitCfg(
476
+ embed_dim=(128, 256, 512, 1024),
477
+ depths=(2, 6, 14, 2),
478
+ stem_width=(64, 128),
479
+ **_rw_coat_cfg(
480
+ stride_mode='dw',
481
+ conv_attn_act_layer='silu',
482
+ init_values=1e-6,
483
+ rel_pos_type='mlp'
484
+ ),
485
+ ),
486
+ coatnet_rmlp_3_rw_224=MaxxVitCfg(
487
+ embed_dim=(192, 384, 768, 1536),
488
+ depths=(2, 6, 14, 2),
489
+ stem_width=(96, 192),
490
+ **_rw_coat_cfg(
491
+ stride_mode='dw',
492
+ conv_attn_act_layer='silu',
493
+ init_values=1e-6,
494
+ rel_pos_type='mlp'
495
+ ),
496
+ ),
497
+
498
+ coatnet_nano_cc_224=MaxxVitCfg(
499
+ embed_dim=(64, 128, 256, 512),
500
+ depths=(3, 4, 6, 3),
501
+ stem_width=(32, 64),
502
+ block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
503
+ **_rw_coat_cfg(),
504
+ ),
505
+ coatnext_nano_rw_224=MaxxVitCfg(
506
+ embed_dim=(64, 128, 256, 512),
507
+ depths=(3, 4, 6, 3),
508
+ stem_width=(32, 64),
509
+ weight_init='normal',
510
+ **_next_cfg(
511
+ rel_pos_type='bias',
512
+ init_values=(1e-5, None)
513
+ ),
514
+ ),
515
+
516
+ # Trying to be like the CoAtNet paper configs
517
+ coatnet_0_224=MaxxVitCfg(
518
+ embed_dim=(96, 192, 384, 768),
519
+ depths=(2, 3, 5, 2),
520
+ stem_width=64,
521
+ ),
522
+ coatnet_1_224=MaxxVitCfg(
523
+ embed_dim=(96, 192, 384, 768),
524
+ depths=(2, 6, 14, 2),
525
+ stem_width=64,
526
+ ),
527
+ coatnet_2_224=MaxxVitCfg(
528
+ embed_dim=(128, 256, 512, 1024),
529
+ depths=(2, 6, 14, 2),
530
+ stem_width=128,
531
+ ),
532
+ coatnet_3_224=MaxxVitCfg(
533
+ embed_dim=(192, 384, 768, 1536),
534
+ depths=(2, 6, 14, 2),
535
+ stem_width=192,
536
+ ),
537
+ coatnet_4_224=MaxxVitCfg(
538
+ embed_dim=(192, 384, 768, 1536),
539
+ depths=(2, 12, 28, 2),
540
+ stem_width=192,
541
+ ),
542
+ coatnet_5_224=MaxxVitCfg(
543
+ embed_dim=(256, 512, 1280, 2048),
544
+ depths=(2, 12, 28, 2),
545
+ stem_width=192,
546
+ ),
547
+
548
+ # Experimental MaxVit configs
549
+ maxvit_pico_rw_256=MaxxVitCfg(
550
+ embed_dim=(32, 64, 128, 256),
551
+ depths=(2, 2, 5, 2),
552
+ block_type=('M',) * 4,
553
+ stem_width=(24, 32),
554
+ **_rw_max_cfg(),
555
+ ),
556
+ maxvit_nano_rw_256=MaxxVitCfg(
557
+ embed_dim=(64, 128, 256, 512),
558
+ depths=(1, 2, 3, 1),
559
+ block_type=('M',) * 4,
560
+ stem_width=(32, 64),
561
+ **_rw_max_cfg(),
562
+ ),
563
+ maxvit_tiny_rw_224=MaxxVitCfg(
564
+ embed_dim=(64, 128, 256, 512),
565
+ depths=(2, 2, 5, 2),
566
+ block_type=('M',) * 4,
567
+ stem_width=(32, 64),
568
+ **_rw_max_cfg(),
569
+ ),
570
+ maxvit_tiny_rw_256=MaxxVitCfg(
571
+ embed_dim=(64, 128, 256, 512),
572
+ depths=(2, 2, 5, 2),
573
+ block_type=('M',) * 4,
574
+ stem_width=(32, 64),
575
+ **_rw_max_cfg(),
576
+ ),
577
+
578
+ maxvit_rmlp_pico_rw_256=MaxxVitCfg(
579
+ embed_dim=(32, 64, 128, 256),
580
+ depths=(2, 2, 5, 2),
581
+ block_type=('M',) * 4,
582
+ stem_width=(24, 32),
583
+ **_rw_max_cfg(rel_pos_type='mlp'),
584
+ ),
585
+ maxvit_rmlp_nano_rw_256=MaxxVitCfg(
586
+ embed_dim=(64, 128, 256, 512),
587
+ depths=(1, 2, 3, 1),
588
+ block_type=('M',) * 4,
589
+ stem_width=(32, 64),
590
+ **_rw_max_cfg(rel_pos_type='mlp'),
591
+ ),
592
+ maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
593
+ embed_dim=(64, 128, 256, 512),
594
+ depths=(2, 2, 5, 2),
595
+ block_type=('M',) * 4,
596
+ stem_width=(32, 64),
597
+ **_rw_max_cfg(rel_pos_type='mlp'),
598
+ ),
599
+ maxvit_rmlp_small_rw_224=MaxxVitCfg(
600
+ embed_dim=(96, 192, 384, 768),
601
+ depths=(2, 2, 5, 2),
602
+ block_type=('M',) * 4,
603
+ stem_width=(32, 64),
604
+ **_rw_max_cfg(
605
+ rel_pos_type='mlp',
606
+ init_values=1e-6,
607
+ ),
608
+ ),
609
+ maxvit_rmlp_small_rw_256=MaxxVitCfg(
610
+ embed_dim=(96, 192, 384, 768),
611
+ depths=(2, 2, 5, 2),
612
+ block_type=('M',) * 4,
613
+ stem_width=(32, 64),
614
+ **_rw_max_cfg(
615
+ rel_pos_type='mlp',
616
+ init_values=1e-6,
617
+ ),
618
+ ),
619
+
620
+ maxvit_tiny_pm_256=MaxxVitCfg(
621
+ embed_dim=(64, 128, 256, 512),
622
+ depths=(2, 2, 5, 2),
623
+ block_type=('PM',) * 4,
624
+ stem_width=(32, 64),
625
+ **_rw_max_cfg(),
626
+ ),
627
+
628
+ maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
629
+ embed_dim=(64, 128, 256, 512),
630
+ depths=(1, 2, 3, 1),
631
+ block_type=('M',) * 4,
632
+ stem_width=(32, 64),
633
+ weight_init='normal',
634
+ **_next_cfg(),
635
+ ),
636
+ maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
637
+ embed_dim=(64, 128, 256, 512),
638
+ depths=(2, 2, 5, 2),
639
+ block_type=('M',) * 4,
640
+ stem_width=(32, 64),
641
+ **_next_cfg(),
642
+ ),
643
+ maxxvit_rmlp_small_rw_256=MaxxVitCfg(
644
+ embed_dim=(96, 192, 384, 768),
645
+ depths=(2, 2, 5, 2),
646
+ block_type=('M',) * 4,
647
+ stem_width=(48, 96),
648
+ **_next_cfg(),
649
+ ),
650
+
651
+ # Trying to be like the MaxViT paper configs
652
+ maxvit_tiny_224=MaxxVitCfg(
653
+ embed_dim=(64, 128, 256, 512),
654
+ depths=(2, 2, 5, 2),
655
+ block_type=('M',) * 4,
656
+ stem_width=64,
657
+ ),
658
+ maxvit_small_224=MaxxVitCfg(
659
+ embed_dim=(96, 192, 384, 768),
660
+ depths=(2, 2, 5, 2),
661
+ block_type=('M',) * 4,
662
+ stem_width=64,
663
+ ),
664
+ maxvit_base_224=MaxxVitCfg(
665
+ embed_dim=(96, 192, 384, 768),
666
+ depths=(2, 6, 14, 2),
667
+ block_type=('M',) * 4,
668
+ stem_width=64,
669
+ ),
670
+ maxvit_large_224=MaxxVitCfg(
671
+ embed_dim=(128, 256, 512, 1024),
672
+ depths=(2, 6, 14, 2),
673
+ block_type=('M',) * 4,
674
+ stem_width=128,
675
+ ),
676
+ maxvit_xlarge_224=MaxxVitCfg(
677
+ embed_dim=(192, 384, 768, 1536),
678
+ depths=(2, 6, 14, 2),
679
+ block_type=('M',) * 4,
680
+ stem_width=192,
681
+ ),
682
+
683
+ )
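Each entry above combines explicit stage settings with a dict of sub-configs unpacked via `**`. The sketch below shows the same pattern with reduced stand-ins. `ConvCfg`, `ModelCfg` and `rw_cfg` are hypothetical simplifications of `MaxxVitConvCfg`, `MaxxVitCfg` and `_rw_max_cfg`, keeping only a few fields.

```python
from dataclasses import dataclass, field
from typing import Tuple


@dataclass
class ConvCfg:  # reduced stand-in for MaxxVitConvCfg
    stride_mode: str = 'dw'
    output_bias: bool = True


@dataclass
class ModelCfg:  # reduced stand-in for MaxxVitCfg
    embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
    depths: Tuple[int, ...] = (2, 3, 5, 2)
    # default_factory gives every ModelCfg its own ConvCfg instance.
    conv_cfg: ConvCfg = field(default_factory=ConvCfg)


def rw_cfg(stride_mode='dw', conv_output_bias=False):
    # Returns keyword arguments for ModelCfg, like the _rw_*_cfg helpers.
    return dict(conv_cfg=ConvCfg(stride_mode=stride_mode, output_bias=conv_output_bias))


nano = ModelCfg(
    embed_dim=(64, 128, 256, 512),
    depths=(3, 4, 6, 3),
    **rw_cfg(stride_mode='pool', conv_output_bias=True),
)
paper_like = ModelCfg()  # no helper: falls back to the dataclass defaults

print(nano)
print(paper_like)
```

Entries that skip the helper, such as the paper-style configs, pick up the dataclass defaults for the nested configs.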
684
+
685
+
686
+ class Attention2d(nn.Module):
687
+ """ multi-head attention for 2D NCHW tensors"""
688
+ def __init__(
689
+ self,
690
+ dim: int,
691
+ dim_out: Optional[int] = None,
692
+ dim_head: int = 32,
693
+ bias: bool = True,
694
+ expand_first: bool = True,
695
+ rel_pos_cls: Callable = None,
696
+ attn_drop: float = 0.,
697
+ proj_drop: float = 0.
698
+ ):
699
+ super().__init__()
700
+ dim_out = dim_out or dim
701
+ dim_attn = dim_out if expand_first else dim
702
+ self.num_heads = dim_attn // dim_head
703
+ self.dim_head = dim_head
704
+ self.scale = dim_head ** -0.5
705
+
706
+ self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
707
+ self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
708
+ self.attn_drop = nn.Dropout(attn_drop)
709
+ self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
710
+ self.proj_drop = nn.Dropout(proj_drop)
711
+
712
+ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
713
+ B, C, H, W = x.shape
714
+
715
+ q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
716
+
717
+ attn = (q.transpose(-2, -1) @ k) * self.scale
718
+ if self.rel_pos is not None:
719
+ attn = self.rel_pos(attn)
720
+ elif shared_rel_pos is not None:
721
+ attn = attn + shared_rel_pos
722
+ attn = attn.softmax(dim=-1)
723
+ attn = self.attn_drop(attn)
724
+
725
+ x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
726
+ x = self.proj(x)
727
+ x = self.proj_drop(x)
728
+ return x
729
+
730
+
731
+ class AttentionCl(nn.Module):
732
+ """ Channels-last multi-head attention (B, ..., C) """
733
+ def __init__(
734
+ self,
735
+ dim: int,
736
+ dim_out: Optional[int] = None,
737
+ dim_head: int = 32,
738
+ bias: bool = True,
739
+ expand_first: bool = True,
740
+ rel_pos_cls: Callable = None,
741
+ attn_drop: float = 0.,
742
+ proj_drop: float = 0.
743
+ ):
744
+ super().__init__()
745
+ dim_out = dim_out or dim
746
+ dim_attn = dim_out if expand_first and dim_out > dim else dim
747
+ assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
748
+ self.num_heads = dim_attn // dim_head
749
+ self.dim_head = dim_head
750
+ self.scale = dim_head ** -0.5
751
+
752
+ self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
753
+ self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
754
+ self.attn_drop = nn.Dropout(attn_drop)
755
+ self.proj = nn.Linear(dim_attn, dim_out, bias=bias)
756
+ self.proj_drop = nn.Dropout(proj_drop)
757
+
758
+ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
759
+ B = x.shape[0]
760
+ restore_shape = x.shape[:-1]
761
+
762
+ q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3)
763
+
764
+ attn = (q @ k.transpose(-2, -1)) * self.scale
765
+ if self.rel_pos is not None:
766
+ attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
767
+ elif shared_rel_pos is not None:
768
+ attn = attn + shared_rel_pos
769
+ attn = attn.softmax(dim=-1)
770
+ attn = self.attn_drop(attn)
771
+
772
+ x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
773
+ x = self.proj(x)
774
+ x = self.proj_drop(x)
775
+ return x
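The two attention classes derive their internal width slightly differently: `Attention2d` uses `dim_out` whenever `expand_first` is set, while `AttentionCl` additionally requires `dim_out > dim`. The sketch below isolates that arithmetic. The function `attn_dims` is a hypothetical helper, and the divisibility assertion is applied to both variants here even though only `AttentionCl` asserts it above.

```python
def attn_dims(dim, dim_out=None, dim_head=32, expand_first=True, channels_last=False):
    """Return (dim_attn, num_heads, scale) as computed in the constructors above."""
    dim_out = dim_out or dim
    if channels_last:
        # AttentionCl: only widen when the output is larger than the input.
        dim_attn = dim_out if expand_first and dim_out > dim else dim
    else:
        # Attention2d: follow dim_out whenever expand_first is set.
        dim_attn = dim_out if expand_first else dim
    assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
    return dim_attn, dim_attn // dim_head, dim_head ** -0.5


print(attn_dims(96, 192))                       # widen before attention
print(attn_dims(96, 192, expand_first=False))   # widen in the output projection
print(attn_dims(192, 96, channels_last=True))   # AttentionCl keeps the wider input dim
print(attn_dims(192, 96, channels_last=False))  # Attention2d follows dim_out
```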
776
+
777
+
778
+ class LayerScale(nn.Module):
779
+ def __init__(self, dim, init_values=1e-5, inplace=False):
780
+ super().__init__()
781
+ self.inplace = inplace
782
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
783
+
784
+ def forward(self, x):
785
+ gamma = self.gamma
786
+ return x.mul_(gamma) if self.inplace else x * gamma
787
+
788
+
789
+ class LayerScale2d(nn.Module):
790
+ def __init__(self, dim, init_values=1e-5, inplace=False):
791
+ super().__init__()
792
+ self.inplace = inplace
793
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
794
+
795
+ def forward(self, x):
796
+ gamma = self.gamma.view(1, -1, 1, 1)
797
+ return x.mul_(gamma) if self.inplace else x * gamma
798
+
799
+
800
+ class Downsample2d(nn.Module):
801
+ """ A downsample pooling module supporting several maxpool and avgpool modes
802
+ * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
803
+ * 'max2' - MaxPool2d w/ kernel_size = stride = 2
804
+ * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
805
+ * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
806
+ """
807
+
808
+ def __init__(
809
+ self,
810
+ dim: int,
811
+ dim_out: int,
812
+ pool_type: str = 'avg2',
813
+ bias: bool = True,
814
+ ):
815
+ super().__init__()
816
+ assert pool_type in ('max', 'max2', 'avg', 'avg2')
817
+ if pool_type == 'max':
818
+ self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
819
+ elif pool_type == 'max2':
820
+ self.pool = nn.MaxPool2d(2) # kernel_size == stride == 2
821
+ elif pool_type == 'avg':
822
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
823
+ else:
824
+ self.pool = nn.AvgPool2d(2) # kernel_size == stride == 2
825
+
826
+ if dim != dim_out:
827
+ self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
828
+ else:
829
+ self.expand = nn.Identity()
830
+
831
+ def forward(self, x):
832
+ x = self.pool(x) # spatial downsample
833
+ x = self.expand(x) # expand chs
834
+ return x
835
+
836
+
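The four modes in the `Downsample2d` docstring differ mainly in how they treat odd spatial sizes. The following is a minimal standalone sketch of that difference, in plain Python without PyTorch. `pool_out_size` and `MODES` are hypothetical names introduced for illustration, and the formula assumes floor rounding (`ceil_mode=False`) and dilation 1.

```python
def pool_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one spatial dim of a pooling layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride, padding) for each mode named in the Downsample2d docstring
MODES = {
    'max': (3, 2, 1),
    'max2': (2, 2, 0),
    'avg': (3, 2, 1),
    'avg2': (2, 2, 0),
}

for size in (14, 15):
    per_mode = {name: pool_out_size(size, *args) for name, args in MODES.items()}
    print(size, per_mode)
```

Under these assumptions, even input sizes give the same output for all modes, while an odd input such as 15 gives 8 for the padded 3x3 modes and 7 for the 2x2 modes.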
837
+ def _init_transformer(module, name, scheme=''):
838
+ if isinstance(module, (nn.Conv2d, nn.Linear)):
839
+ if scheme == 'normal':
840
+ nn.init.normal_(module.weight, std=.02)
841
+ if module.bias is not None:
842
+ nn.init.zeros_(module.bias)
843
+ elif scheme == 'trunc_normal':
844
+ trunc_normal_tf_(module.weight, std=.02)
845
+ if module.bias is not None:
846
+ nn.init.zeros_(module.bias)
847
+ elif scheme == 'xavier_normal':
848
+ nn.init.xavier_normal_(module.weight)
849
+ if module.bias is not None:
850
+ nn.init.zeros_(module.bias)
851
+ else:
852
+ # vit like
853
+ nn.init.xavier_uniform_(module.weight)
854
+ if module.bias is not None:
855
+ if 'mlp' in name:
856
+ nn.init.normal_(module.bias, std=1e-6)
857
+ else:
858
+ nn.init.zeros_(module.bias)
859
+
860
+
861
+ class TransformerBlock2d(nn.Module):
862
+ """ Transformer block with 2D downsampling
863
+ '2D' NCHW tensor layout
864
+
865
+ Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
866
+ for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.
867
+
868
+ This impl was faster on TPU w/ PT XLA than the 1D experiment.
869
+ """
870
+
871
+ def __init__(
872
+ self,
873
+ dim: int,
874
+ dim_out: int,
875
+ stride: int = 1,
876
+ rel_pos_cls: Optional[Callable] = None,
877
+ cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
878
+ drop_path: float = 0.,
879
+ ):
880
+ super().__init__()
881
+ norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
882
+ act_layer = get_act_layer(cfg.act_layer)
883
+
884
+ if stride == 2:
885
+ self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias)
886
+ self.norm1 = nn.Sequential(OrderedDict([
887
+ ('norm', norm_layer(dim)),
888
+ ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)),
889
+ ]))
890
+ else:
891
+ assert dim == dim_out
892
+ self.shortcut = nn.Identity()
893
+ self.norm1 = norm_layer(dim)
894
+
895
+ self.attn = Attention2d(
896
+ dim,
897
+ dim_out,
898
+ dim_head=cfg.dim_head,
899
+ expand_first=cfg.expand_first,
900
+ bias=cfg.attn_bias,
901
+ rel_pos_cls=rel_pos_cls,
902
+ attn_drop=cfg.attn_drop,
903
+ proj_drop=cfg.proj_drop
904
+ )
905
+ self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
906
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
907
+
908
+ self.norm2 = norm_layer(dim_out)
909
+ self.mlp = ConvMlp(
910
+ in_features=dim_out,
911
+ hidden_features=int(dim_out * cfg.expand_ratio),
912
+ act_layer=act_layer,
913
+ drop=cfg.proj_drop)
914
+ self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
915
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
916
+
917
+ def init_weights(self, scheme=''):
918
+ named_apply(partial(_init_transformer, scheme=scheme), self)
919
+
920
+ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
921
+ x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
922
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
923
+ return x
924
+
925
+
926
+ def _init_conv(module, name, scheme=''):
927
+ if isinstance(module, nn.Conv2d):
928
+ if scheme == 'normal':
929
+ nn.init.normal_(module.weight, std=.02)
930
+ if module.bias is not None:
931
+ nn.init.zeros_(module.bias)
932
+ elif scheme == 'trunc_normal':
933
+ trunc_normal_tf_(module.weight, std=.02)
934
+ if module.bias is not None:
935
+ nn.init.zeros_(module.bias)
936
+ elif scheme == 'xavier_normal':
937
+ nn.init.xavier_normal_(module.weight)
938
+ if module.bias is not None:
939
+ nn.init.zeros_(module.bias)
940
+ else:
941
+ # efficientnet like
942
+ fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
943
+ fan_out //= module.groups
944
+ nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
945
+ if module.bias is not None:
946
+ nn.init.zeros_(module.bias)
947
+
948
+
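The "efficientnet like" fallback in `_init_conv` draws weights from a normal distribution whose standard deviation depends on the per-group fan-out. The following is a small sketch of that arithmetic only, with no tensors involved. `effnet_conv_std` is a hypothetical helper name, not something defined in this file.

```python
import math


def effnet_conv_std(kernel_size, out_channels, groups=1):
    """Std used by the default branch: sqrt(2 / fan_out), with fan_out divided by groups."""
    fan_out = kernel_size[0] * kernel_size[1] * out_channels
    fan_out //= groups
    return math.sqrt(2.0 / fan_out)


# Dense 3x3 conv, depthwise 3x3 conv (groups == out_channels), and a 1x1 projection
print(effnet_conv_std((3, 3), 64))
print(effnet_conv_std((3, 3), 64, groups=64))
print(effnet_conv_std((1, 1), 256))
```

For a depthwise conv the fan-out reduces to the kernel area alone, so its weights start with a much larger spread than a dense conv of the same width.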
949
+ def num_groups(group_size, channels):
950
+ if not group_size: # 0 or None
951
+ return 1 # normal conv with 1 group
952
+ else:
953
+ # NOTE group_size == 1 -> depthwise conv
954
+ assert channels % group_size == 0
955
+ return channels // group_size
956
+
957
+
958
+ class MbConvBlock(nn.Module):
959
+ """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
960
+ """
961
+ def __init__(
962
+ self,
963
+ in_chs: int,
964
+ out_chs: int,
965
+ stride: int = 1,
966
+ dilation: Tuple[int, int] = (1, 1),
967
+ cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
968
+ drop_path: float = 0.
969
+ ):
970
+ super().__init__()
971
+ norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
972
+ mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio)
973
+ groups = num_groups(cfg.group_size, mid_chs)
974
+
975
+ if stride == 2:
976
+ self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias)
977
+ else:
978
+ self.shortcut = nn.Identity()
979
+
980
+ assert cfg.stride_mode in ('pool', '1x1', 'dw')
981
+ stride_pool, stride_1, stride_2 = 1, 1, 1
982
+ if cfg.stride_mode == 'pool':
983
+ # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1
984
+ stride_pool, dilation_2 = stride, dilation[1]
985
+ # FIXME handle dilation of avg pool
986
+ elif cfg.stride_mode == '1x1':
987
+ # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away
988
+ stride_1, dilation_2 = stride, dilation[1]
989
+ else:
990
+ stride_2, dilation_2 = stride, dilation[0]
991
+
992
+ self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act)
993
+ if stride_pool > 1:
994
+ self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type)
995
+ else:
996
+ self.down = nn.Identity()
997
+ self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1)
998
+ self.norm1 = norm_act_layer(mid_chs)
999
+
1000
+ self.conv2_kxk = create_conv2d(
1001
+ mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups)
1002
+
1003
+ attn_kwargs = {}
1004
+ if isinstance(cfg.attn_layer, str):
1005
+ if cfg.attn_layer in ('se', 'eca'):
1006
+ attn_kwargs['act_layer'] = cfg.attn_act_layer
1007
+ attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))
1008
+
1009
+ # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2)
1010
+ if cfg.attn_early:
1011
+ self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)
1012
+ self.norm2 = norm_act_layer(mid_chs)
1013
+ self.se = None
1014
+ else:
1015
+ self.se_early = None
1016
+ self.norm2 = norm_act_layer(mid_chs)
1017
+ self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)
1018
+
1019
+ self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias)
1020
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1021
+
1022
+ def init_weights(self, scheme=''):
1023
+ named_apply(partial(_init_conv, scheme=scheme), self)
1024
+
1025
+ def forward(self, x):
1026
+ shortcut = self.shortcut(x)
1027
+ x = self.pre_norm(x)
1028
+ x = self.down(x)
1029
+
1030
+ # 1x1 expansion conv & norm-act
1031
+ x = self.conv1_1x1(x)
1032
+ x = self.norm1(x)
1033
+
1034
+ # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
1035
+ x = self.conv2_kxk(x)
1036
+ if self.se_early is not None:
1037
+ x = self.se_early(x)
1038
+ x = self.norm2(x)
1039
+ if self.se is not None:
1040
+ x = self.se(x)
1041
+
1042
+ # 1x1 linear projection to output width
1043
+ x = self.conv3_1x1(x)
1044
+ x = self.drop_path(x) + shortcut
1045
+ return x
1046
+
1047
+
1048
+ class ConvNeXtBlock(nn.Module):
1049
+ """ ConvNeXt Block
1050
+ """
1051
+
1052
+ def __init__(
1053
+ self,
1054
+ in_chs: int,
1055
+ out_chs: Optional[int] = None,
1056
+ kernel_size: int = 7,
1057
+ stride: int = 1,
1058
+ dilation: Tuple[int, int] = (1, 1),
1059
+ cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
1060
+ conv_mlp: bool = True,
1061
+ drop_path: float = 0.
1062
+ ):
1063
+ super().__init__()
1064
+ out_chs = out_chs or in_chs
1065
+ act_layer = get_act_layer(cfg.act_layer)
1066
+ if conv_mlp:
1067
+ norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
1068
+ mlp_layer = ConvMlp
1069
+ else:
1070
+ assert 'layernorm' in cfg.norm_layer
1071
+ norm_layer = LayerNorm
1072
+ mlp_layer = Mlp
1073
+ self.use_conv_mlp = conv_mlp
1074
+
1075
+ if stride == 2:
1076
+ self.shortcut = Downsample2d(in_chs, out_chs)
1077
+ elif in_chs != out_chs:
1078
+ self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias)
1079
+ else:
1080
+ self.shortcut = nn.Identity()
1081
+
1082
+ assert cfg.stride_mode in ('pool', 'dw')
1083
+ stride_pool, stride_dw = 1, 1
1084
+ # FIXME handle dilation?
1085
+ if cfg.stride_mode == 'pool':
1086
+ stride_pool = stride
1087
+ else:
1088
+ stride_dw = stride
1089
+
1090
+ if stride_pool == 2:
1091
+ self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type)
1092
+ else:
1093
+ self.down = nn.Identity()
1094
+
1095
+ self.conv_dw = create_conv2d(
1096
+ in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1],
1097
+ depthwise=True, bias=cfg.output_bias)
1098
+ self.norm = norm_layer(out_chs)
1099
+ self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer)
1100
+ if conv_mlp:
1101
+ self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
1102
+ else:
1103
+ self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
1104
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1105
+
1106
+ def forward(self, x):
1107
+ shortcut = self.shortcut(x)
1108
+ x = self.down(x)
1109
+ x = self.conv_dw(x)
1110
+ if self.use_conv_mlp:
1111
+ x = self.norm(x)
1112
+ x = self.mlp(x)
1113
+ x = self.ls(x)
1114
+ else:
1115
+ x = x.permute(0, 2, 3, 1)
1116
+ x = self.norm(x)
1117
+ x = self.mlp(x)
1118
+ x = self.ls(x)
1119
+ x = x.permute(0, 3, 1, 2)
1120
+
1121
+ x = self.drop_path(x) + shortcut
1122
+ return x
1123
+
1124
+
1125
+ def window_partition(x, window_size: List[int]):
1126
+ B, H, W, C = x.shape
1127
+ _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
1128
+ _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
1129
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
1130
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
1131
+ return windows
1132
+
1133
+
1134
+ @register_notrace_function # reason: int argument is a Proxy
1135
+ def window_reverse(windows, window_size: List[int], img_size: List[int]):
1136
+ H, W = img_size
1137
+ C = windows.shape[-1]
1138
+ x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
1139
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
1140
+ return x
1141
+
1142
+
1143
+ def grid_partition(x, grid_size: List[int]):
1144
+ B, H, W, C = x.shape
1145
+ _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
1146
+ _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
1147
+ x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
1148
+ windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)
1149
+ return windows
1150
+
1151
+
1152
+ @register_notrace_function # reason: int argument is a Proxy
1153
+ def grid_reverse(windows, grid_size: List[int], img_size: List[int]):
1154
+ H, W = img_size
1155
+ C = windows.shape[-1]
1156
+ x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C)
1157
+ x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)
1158
+ return x
1159
+
1160
+
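To make the difference between window (block) and grid partitioning concrete, here is a pure-Python sketch that lists which `(h, w)` positions end up in the same attention group. It follows the index arithmetic implied by the `view` / `permute` calls in `window_partition` and `grid_partition`. It assumes a single square partition size that divides both `H` and `W`. `window_groups` and `grid_groups` are hypothetical helpers written for illustration.

```python
def window_groups(H, W, size):
    """Each group is a contiguous size x size block of positions."""
    groups = []
    for hb in range(H // size):
        for wb in range(W // size):
            groups.append([(hb * size + i, wb * size + j)
                           for i in range(size) for j in range(size)])
    return groups


def grid_groups(H, W, size):
    """Each group holds size x size positions strided by H // size and W // size."""
    sh, sw = H // size, W // size
    groups = []
    for hr in range(sh):
        for wr in range(sw):
            groups.append([(gi * sh + hr, gj * sw + wr)
                           for gi in range(size) for gj in range(size)])
    return groups


print(window_groups(4, 4, 2)[0])  # local neighbourhood
print(grid_groups(4, 4, 2)[0])    # sparse, image-wide sample
```

Both schemes split the feature map into the same number of equally sized groups. The window variant attends locally, while the grid variant mixes positions spread across the whole map.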
1161
+ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
1162
+ rel_pos_cls = None
1163
+ if cfg.rel_pos_type == 'mlp':
1164
+ rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
1165
+ elif cfg.rel_pos_type == 'bias':
1166
+ rel_pos_cls = partial(RelPosBias, window_size=window_size)
1167
+ return rel_pos_cls
1168
+
1169
+
1170
+ class PartitionAttentionCl(nn.Module):
1171
+ """ Grid or Block partition + Attn + FFN.
1172
+ NxC 'channels last' tensor layout.
1173
+ """
1174
+
1175
+ def __init__(
1176
+ self,
1177
+ dim: int,
1178
+ partition_type: str = 'block',
1179
+ cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
1180
+ drop_path: float = 0.,
1181
+ ):
1182
+ super().__init__()
1183
+ norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last
1184
+ act_layer = get_act_layer(cfg.act_layer)
1185
+
1186
+ self.partition_block = partition_type == 'block'
1187
+ self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
1188
+ rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)
1189
+
1190
+ self.norm1 = norm_layer(dim)
1191
+ self.attn = AttentionCl(
1192
+ dim,
1193
+ dim,
1194
+ dim_head=cfg.dim_head,
1195
+ bias=cfg.attn_bias,
1196
+ rel_pos_cls=rel_pos_cls,
1197
+ attn_drop=cfg.attn_drop,
1198
+ proj_drop=cfg.proj_drop,
1199
+ )
1200
+ self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
1201
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1202
+
1203
+ self.norm2 = norm_layer(dim)
1204
+ self.mlp = Mlp(
1205
+ in_features=dim,
1206
+ hidden_features=int(dim * cfg.expand_ratio),
1207
+ act_layer=act_layer,
1208
+ drop=cfg.proj_drop)
1209
+ self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
1210
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1211
+
1212
+ def _partition_attn(self, x):
1213
+ img_size = x.shape[1:3]
1214
+ if self.partition_block:
1215
+ partitioned = window_partition(x, self.partition_size)
1216
+ else:
1217
+ partitioned = grid_partition(x, self.partition_size)
1218
+
1219
+ partitioned = self.attn(partitioned)
1220
+
1221
+ if self.partition_block:
1222
+ x = window_reverse(partitioned, self.partition_size, img_size)
1223
+ else:
1224
+ x = grid_reverse(partitioned, self.partition_size, img_size)
1225
+ return x
1226
+
1227
+ def forward(self, x):
1228
+ x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
1229
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
1230
+ return x
1231
+
1232
+
1233
+ class ParallelPartitionAttention(nn.Module):
1234
+ """ Experimental. Grid and Block partition + single FFN
1235
+ NxC tensor layout.
1236
+ """
1237
+
1238
+ def __init__(
1239
+ self,
1240
+ dim: int,
1241
+ cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
1242
+ drop_path: float = 0.,
1243
+ ):
1244
+ super().__init__()
1245
+ assert dim % 2 == 0
1246
+ norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last
1247
+ act_layer = get_act_layer(cfg.act_layer)
1248
+
1249
+ assert cfg.window_size == cfg.grid_size
1250
+ self.partition_size = to_2tuple(cfg.window_size)
1251
+ rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)
1252
+
1253
+ self.norm1 = norm_layer(dim)
1254
+ self.attn_block = AttentionCl(
1255
+ dim,
1256
+ dim // 2,
1257
+ dim_head=cfg.dim_head,
1258
+ bias=cfg.attn_bias,
1259
+ rel_pos_cls=rel_pos_cls,
1260
+ attn_drop=cfg.attn_drop,
1261
+ proj_drop=cfg.proj_drop,
1262
+ )
1263
+ self.attn_grid = AttentionCl(
1264
+ dim,
1265
+ dim // 2,
1266
+ dim_head=cfg.dim_head,
1267
+ bias=cfg.attn_bias,
1268
+ rel_pos_cls=rel_pos_cls,
1269
+ attn_drop=cfg.attn_drop,
1270
+ proj_drop=cfg.proj_drop,
1271
+ )
1272
+ self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
1273
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1274
+
1275
+ self.norm2 = norm_layer(dim)
1276
+ self.mlp = Mlp(
1277
+ in_features=dim,
1278
+ hidden_features=int(dim * cfg.expand_ratio),
1279
+ out_features=dim,
1280
+ act_layer=act_layer,
1281
+ drop=cfg.proj_drop)
1282
+ self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
1283
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1284
+
1285
+ def _partition_attn(self, x):
1286
+ img_size = x.shape[1:3]
1287
+
1288
+ partitioned_block = window_partition(x, self.partition_size)
1289
+ partitioned_block = self.attn_block(partitioned_block)
1290
+ x_window = window_reverse(partitioned_block, self.partition_size, img_size)
1291
+
1292
+ partitioned_grid = grid_partition(x, self.partition_size)
1293
+ partitioned_grid = self.attn_grid(partitioned_grid)
1294
+ x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size)
1295
+
1296
+ return torch.cat([x_window, x_grid], dim=-1)
1297
+
1298
+ def forward(self, x):
1299
+ x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
1300
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
1301
+ return x
1302
+
1303
+
1304
+ def window_partition_nchw(x, window_size: List[int]):
1305
+ B, C, H, W = x.shape
1306
+ _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
1307
+ _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
1308
+ x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
1309
+ windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
1310
+ return windows
1311
+
1312
+
1313
+ @register_notrace_function # reason: int argument is a Proxy
1314
+ def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]):
1315
+ H, W = img_size
1316
+ C = windows.shape[1]
1317
+ x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1])
1318
+ x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
1319
+ return x
1320
+
1321
+
1322
+ def grid_partition_nchw(x, grid_size: List[int]):
1323
+ B, C, H, W = x.shape
1324
+ _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
1325
+ _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
1326
+ x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1])
1327
+ windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1])
1328
+ return windows
1329
+
1330
+
1331
+ @register_notrace_function # reason: int argument is a Proxy
1332
+ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
1333
+ H, W = img_size
1334
+ C = windows.shape[1]
1335
+ x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1])
1336
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W)
1337
+ return x
1338
+
1339
+
1340
+ class PartitionAttention2d(nn.Module):
1341
+ """ Grid or Block partition + Attn + FFN
1342
+
1343
+ '2D' NCHW tensor layout.
1344
+ """
1345
+
1346
+ def __init__(
1347
+ self,
1348
+ dim: int,
1349
+ partition_type: str = 'block',
1350
+ cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
1351
+ drop_path: float = 0.,
1352
+ ):
1353
+ super().__init__()
1354
+ norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)  # NOTE this block is NCHW (channels-first)
1355
+ act_layer = get_act_layer(cfg.act_layer)
1356
+
1357
+ self.partition_block = partition_type == 'block'
1358
+ self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
1359
+ rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)
1360
+
1361
+ self.norm1 = norm_layer(dim)
1362
+ self.attn = Attention2d(
1363
+ dim,
1364
+ dim,
1365
+ dim_head=cfg.dim_head,
1366
+ bias=cfg.attn_bias,
1367
+ rel_pos_cls=rel_pos_cls,
1368
+ attn_drop=cfg.attn_drop,
1369
+ proj_drop=cfg.proj_drop,
1370
+ )
1371
+ self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
1372
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1373
+
1374
+ self.norm2 = norm_layer(dim)
1375
+ self.mlp = ConvMlp(
1376
+ in_features=dim,
1377
+ hidden_features=int(dim * cfg.expand_ratio),
1378
+ act_layer=act_layer,
1379
+ drop=cfg.proj_drop)
1380
+ self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
1381
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1382
+
1383
+ def _partition_attn(self, x):
1384
+ img_size = x.shape[-2:]
1385
+ if self.partition_block:
1386
+ partitioned = window_partition_nchw(x, self.partition_size)
1387
+ else:
1388
+ partitioned = grid_partition_nchw(x, self.partition_size)
1389
+
1390
+ partitioned = self.attn(partitioned)
1391
+
1392
+ if self.partition_block:
1393
+ x = window_reverse_nchw(partitioned, self.partition_size, img_size)
1394
+ else:
1395
+ x = grid_reverse_nchw(partitioned, self.partition_size, img_size)
1396
+ return x
1397
+
1398
+ def forward(self, x):
1399
+ x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
1400
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
1401
+ return x
1402
+
1403
+
1404
+ class MaxxVitBlock(nn.Module):
1405
+ """ MaxVit block: conv, window partition + FFN, grid partition + FFN
1406
+ """
1407
+
1408
+ def __init__(
1409
+ self,
1410
+ dim: int,
1411
+ dim_out: int,
1412
+ stride: int = 1,
1413
+ conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
1414
+ transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
1415
+ use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
1416
+ drop_path: float = 0.,
1417
+ ):
1418
+ super().__init__()
1419
+
1420
+ conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
1421
+ self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
1422
+
1423
+ attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
1424
+ partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
1425
+ self.nchw_attn = use_nchw_attn
1426
+ self.attn_block = partition_layer(**attn_kwargs)
1427
+ self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
1428
+
1429
+ def init_weights(self, scheme=''):
1430
+ named_apply(partial(_init_transformer, scheme=scheme), self.attn_block)
1431
+ named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid)
1432
+ named_apply(partial(_init_conv, scheme=scheme), self.conv)
1433
+
1434
+ def forward(self, x):
1435
+ # NCHW format
1436
+ x = self.conv(x)
1437
+
1438
+ if not self.nchw_attn:
1439
+ x = x.permute(0, 2, 3, 1) # to NHWC (channels-last)
1440
+ x = self.attn_block(x)
1441
+ x = self.attn_grid(x)
1442
+ if not self.nchw_attn:
1443
+ x = x.permute(0, 3, 1, 2) # back to NCHW
1444
+ return x
1445
+
1446
+
1447
+ class ParallelMaxxVitBlock(nn.Module):
1448
+ """ MaxVit block with parallel cat(window + grid), one FF
1449
+ Experimental timm block.
1450
+ """
1451
+
1452
+ def __init__(
1453
+ self,
1454
+ dim,
1455
+ dim_out,
1456
+ stride=1,
1457
+ num_conv=2,
1458
+ conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
1459
+ transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
1460
+ drop_path=0.,
1461
+ ):
1462
+ super().__init__()
1463
+
1464
+ conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
1465
+ if num_conv > 1:
1466
+ convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)]
1467
+ convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path) for _ in range(num_conv - 1)]
1468
+ self.conv = nn.Sequential(*convs)
1469
+ else:
1470
+ self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
1471
+ self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
1472
+
1473
+ def init_weights(self, scheme=''):
1474
+ named_apply(partial(_init_transformer, scheme=scheme), self.attn)
1475
+ named_apply(partial(_init_conv, scheme=scheme), self.conv)
1476
+
1477
+ def forward(self, x):
1478
+ x = self.conv(x)
1479
+ x = x.permute(0, 2, 3, 1)
1480
+ x = self.attn(x)
1481
+ x = x.permute(0, 3, 1, 2)
1482
+ return x
1483
+
1484
+
1485
+ class MaxxVitStage(nn.Module):
1486
+ def __init__(
1487
+ self,
1488
+ in_chs: int,
1489
+ out_chs: int,
1490
+ stride: int = 2,
1491
+ depth: int = 4,
1492
+ feat_size: Tuple[int, int] = (14, 14),
1493
+ block_types: Union[str, Tuple[str]] = 'C',
1494
+ transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
1495
+ conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
1496
+ drop_path: Union[float, List[float]] = 0.,
1497
+ ):
1498
+ super().__init__()
1499
+ self.grad_checkpointing = False
1500
+
1501
+ block_types = extend_tuple(block_types, depth)
1502
+ blocks = []
1503
+ for i, t in enumerate(block_types):
1504
+ block_stride = stride if i == 0 else 1
1505
+ assert t in ('C', 'T', 'M', 'PM')
1506
+ if t == 'C':
1507
+ conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
1508
+ blocks += [conv_cls(
1509
+ in_chs,
1510
+ out_chs,
1511
+ stride=block_stride,
1512
+ cfg=conv_cfg,
1513
+ drop_path=drop_path[i],
1514
+ )]
1515
+ elif t == 'T':
1516
+ rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size)
1517
+ blocks += [TransformerBlock2d(
1518
+ in_chs,
1519
+ out_chs,
1520
+ stride=block_stride,
1521
+ rel_pos_cls=rel_pos_cls,
1522
+ cfg=transformer_cfg,
1523
+ drop_path=drop_path[i],
1524
+ )]
1525
+ elif t == 'M':
1526
+ blocks += [MaxxVitBlock(
1527
+ in_chs,
1528
+ out_chs,
1529
+ stride=block_stride,
1530
+ conv_cfg=conv_cfg,
1531
+ transformer_cfg=transformer_cfg,
1532
+ drop_path=drop_path[i],
1533
+ )]
1534
+ elif t == 'PM':
1535
+ blocks += [ParallelMaxxVitBlock(
1536
+ in_chs,
1537
+ out_chs,
1538
+ stride=block_stride,
1539
+ conv_cfg=conv_cfg,
1540
+ transformer_cfg=transformer_cfg,
1541
+ drop_path=drop_path[i],
1542
+ )]
1543
+ in_chs = out_chs
1544
+ self.blocks = nn.Sequential(*blocks)
1545
+
1546
+ def forward(self, x):
1547
+ if self.grad_checkpointing and not torch.jit.is_scripting():
1548
+ x = checkpoint_seq(self.blocks, x)
1549
+ else:
1550
+ x = self.blocks(x)
1551
+ return x
1552
+
1553
+
1554
+ class Stem(nn.Module):
1555
+
1556
+ def __init__(
1557
+ self,
1558
+ in_chs: int,
1559
+ out_chs: int,
1560
+ kernel_size: int = 3,
1561
+ act_layer: str = 'gelu',
1562
+ norm_layer: str = 'batchnorm2d',
1563
+ norm_eps: float = 1e-5,
1564
+ ):
1565
+ super().__init__()
1566
+ if not isinstance(out_chs, (list, tuple)):
1567
+ out_chs = to_2tuple(out_chs)
1568
+
1569
+ norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
1570
+ self.out_chs = out_chs[-1]
1571
+ self.stride = 2
1572
+
1573
+ self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2)
1574
+ self.norm1 = norm_act_layer(out_chs[0])
1575
+ self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1)
1576
+
1577
+ def init_weights(self, scheme=''):
1578
+ named_apply(partial(_init_conv, scheme=scheme), self)
1579
+
1580
+ def forward(self, x):
1581
+ x = self.conv1(x)
1582
+ x = self.norm1(x)
1583
+ x = self.conv2(x)
1584
+ return x
1585
+
1586
+
1587
+ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
1588
+ if cfg.window_size is not None:
1589
+ assert cfg.grid_size
1590
+ return cfg
1591
+ partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio
1592
+ cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
1593
+ return cfg
1594
+
1595
+
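`cfg_window_size` leaves explicitly configured partition sizes alone and otherwise derives them from the image size. The following sketch reproduces that logic with a stripped-down stand-in config. `TransformerCfgSketch` and `resolve_partition` are hypothetical names, the stand-in holds only the three fields touched here, and the `partition_ratio` default of 32 is an assumption made for the example.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass
class TransformerCfgSketch:
    # Hypothetical stand-in for MaxxVitTransformerCfg (only the relevant fields)
    partition_ratio: int = 32  # assumed value, for illustration only
    window_size: Optional[Tuple[int, int]] = None
    grid_size: Optional[Tuple[int, int]] = None


def resolve_partition(cfg, img_size):
    """Mirror of the cfg_window_size logic on the stand-in config."""
    if cfg.window_size is not None:
        assert cfg.grid_size  # an explicit window size requires an explicit grid size
        return cfg
    size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio
    return replace(cfg, window_size=size, grid_size=size)  # returns a new dataclass copy


print(resolve_partition(TransformerCfgSketch(), (224, 224)))
print(resolve_partition(TransformerCfgSketch(), (384, 256)))
```

Because `replace` builds a new instance, the original config object is left unmodified, so one shared config can be resolved for several image sizes.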
1596
+ class MaxxVit(nn.Module):
1597
+ """ CoAtNet + MaxVit base model.
1598
+
1599
+ Highly configurable for different block compositions, tensor layouts, pooling types.
1600
+ """
1601
+
1602
+ def __init__(
1603
+ self,
1604
+ cfg: MaxxVitCfg,
1605
+ img_size: Union[int, Tuple[int, int]] = 224,
1606
+ in_chans: int = 3,
1607
+ num_classes: int = 1000,
1608
+ global_pool: str = 'avg',
1609
+ drop_rate: float = 0.,
1610
+ drop_path_rate: float = 0.
1611
+ ):
1612
+ super().__init__()
1613
+ img_size = to_2tuple(img_size)
1614
+ transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
1615
+ self.num_classes = num_classes
1616
+ self.global_pool = global_pool
1617
+ self.num_features = cfg.embed_dim[-1]
1618
+ self.embed_dim = cfg.embed_dim
1619
+ self.drop_rate = drop_rate
1620
+ self.grad_checkpointing = False
1621
+
1622
+ self.stem = Stem(
1623
+ in_chs=in_chans,
1624
+ out_chs=cfg.stem_width,
1625
+ act_layer=cfg.conv_cfg.act_layer,
1626
+ norm_layer=cfg.conv_cfg.norm_layer,
1627
+ norm_eps=cfg.conv_cfg.norm_eps,
1628
+ )
1629
+
1630
+ stride = self.stem.stride
1631
+ feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
1632
+
1633
+ num_stages = len(cfg.embed_dim)
1634
+ assert len(cfg.depths) == num_stages
1635
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
1636
+ in_chs = self.stem.out_chs
1637
+ stages = []
1638
+ for i in range(num_stages):
1639
+ stage_stride = 2
1640
+ out_chs = cfg.embed_dim[i]
1641
+ feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size])
1642
+ stages += [MaxxVitStage(
1643
+ in_chs,
1644
+ out_chs,
1645
+ depth=cfg.depths[i],
1646
+ block_types=cfg.block_type[i],
1647
+ conv_cfg=cfg.conv_cfg,
1648
+ transformer_cfg=transformer_cfg,
1649
+ feat_size=feat_size,
1650
+ drop_path=dpr[i],
1651
+ )]
1652
+ stride *= stage_stride
1653
+ in_chs = out_chs
1654
+ self.stages = nn.Sequential(*stages)
1655
+
1656
+ final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer)
1657
+ self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps)
1658
+
1659
+ # Classifier head
1660
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
1661
+
1662
+ # Weight init (default PyTorch init works well for AdamW if scheme not set)
1663
+ assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
1664
+ if cfg.weight_init:
1665
+ named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)
1666
+
1667
+ def _init_weights(self, module, name, scheme=''):
1668
+ if hasattr(module, 'init_weights'):
1669
+ try:
1670
+ module.init_weights(scheme=scheme)
1671
+ except TypeError:
1672
+ module.init_weights()
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {
+             k for k, _ in self.named_parameters()
+             if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
+
+     @torch.jit.ignore
+     def group_matcher(self, coarse=False):
+         matcher = dict(
+             stem=r'^stem',  # stem and embed
+             blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+         )
+         return matcher
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         for s in self.stages:
+             s.grad_checkpointing = enable
+
+     @torch.jit.ignore
+     def get_classifier(self):
+         return self.head.fc
+
+     def reset_classifier(self, num_classes, global_pool=None):
+         self.num_classes = num_classes
+         if global_pool is None:
+             global_pool = self.head.global_pool.pool_type
+         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+     def forward_features(self, x):
+         x = self.stem(x)
+         x = self.stages(x)
+         x = self.norm(x)
+         return x
+
+     def forward_head(self, x, pre_logits: bool = False):
+         return self.head(x, pre_logits=pre_logits)
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         x = self.forward_head(x)
+         return x
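As a rough illustration of the stage loop in `__init__`, the sketch below reproduces two pieces of its arithmetic in plain Python: the per-stage split of a linear drop-path schedule, and the ceil-division update of `feat_size`. The helper names `split_drop_path` and `stage_feat_sizes`, the sample depths, and the 112x112 starting size are hypothetical and not part of this commit; the linear ramp matches `torch.linspace` only up to floating-point rounding.

```python
from typing import List, Sequence, Tuple


def split_drop_path(drop_path_rate: float, depths: Sequence[int]) -> List[List[float]]:
    """Linear ramp from 0 to drop_path_rate over all blocks, chunked per stage."""
    total = sum(depths)
    # Evenly spaced values with both endpoints included (hypothetical stand-in
    # for linspace(0, rate, total) followed by split(depths)).
    if total == 1:
        ramp = [0.0]
    else:
        ramp = [drop_path_rate * i / (total - 1) for i in range(total)]
    out, start = [], 0
    for d in depths:
        out.append(ramp[start:start + d])
        start += d
    return out


def stage_feat_sizes(
        feat_size: Tuple[int, int],
        num_stages: int,
        stage_stride: int = 2,
) -> List[Tuple[int, ...]]:
    """Feature-map size after each stage; (r - 1) // s + 1 is ceil(r / s)."""
    sizes = []
    for _ in range(num_stages):
        feat_size = tuple((r - 1) // stage_stride + 1 for r in feat_size)
        sizes.append(feat_size)
    return sizes


if __name__ == "__main__":
    print(split_drop_path(0.3, [2, 2, 5, 2]))
    print(stage_feat_sizes((112, 112), 4))
```

Each stage receives its own slice of the schedule, so later blocks get a higher drop-path rate, and the ceil division keeps odd feature-map sizes from rounding down to zero.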
+
+
+ def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
+     return build_model_with_cfg(
+         MaxxVit, variant, pretrained,
+         model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
+         feature_cfg=dict(flatten_sequential=True),
+         **kwargs)
+
+
+ @register_model
+ def coatnet_pico_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_nano_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_0_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_1_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_2_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_3_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_bn_0_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_nano_cc_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnext_nano_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_0_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_1_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_2_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_3_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_4_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def coatnet_5_224(pretrained=False, **kwargs):
+     return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_pico_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_nano_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_tiny_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_tiny_pm_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
+     return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_tiny_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_small_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_small_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_base_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_base_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_large_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_large_224', pretrained=pretrained, **kwargs)
+
+
+ @register_model
+ def maxvit_xlarge_224(pretrained=False, **kwargs):
+     return _create_maxxvit('maxvit_xlarge_224', pretrained=pretrained, **kwargs)