# EfficientViT model definitions (source snapshot at revision fb56b04, 16,438 bytes).
import torch
import itertools
from timm.models.vision_transformer import trunc_normal_
from timm.models.layers import SqueezeExcite
from timm.models.registry import register_model
class Conv2d_BN(torch.nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``, foldable into one conv.

    Submodules are registered as ``c`` (the convolution) and ``bn`` (the
    batch norm).

    Args:
        a (int): Input channels.
        b (int): Output channels.
        ks (int): Kernel size.
        stride (int): Convolution stride.
        pad (int): Zero padding.
        dilation (int): Convolution dilation.
        groups (int): Convolution groups.
        bn_weight_init (float): Initial value of the BN scale (gamma).
        resolution (int): Accepted for interface compatibility; unused here.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return one ``Conv2d`` equivalent to this conv + BN in eval mode.

        The BN running statistics are folded into the conv weight and a new
        bias. The fused layer is created on the same device and with the same
        dtype as the existing conv weight. Without this, fusing a model that
        lives on a GPU (or in float64/float16) would yield a CPU float32 layer.

        Returns:
            torch.nn.Conv2d: The fused convolution (with bias).
        """
        c, bn = self._modules.values()
        # Per-output-channel factor gamma / sqrt(running_var + eps).
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * scale[:, None, None, None]
        b = bn.bias - bn.running_mean * scale
        m = torch.nn.Conv2d(
            w.size(1) * c.groups, w.size(0), w.shape[2:],
            stride=c.stride, padding=c.padding, dilation=c.dilation,
            groups=c.groups, device=c.weight.device, dtype=c.weight.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class BN_Linear(torch.nn.Sequential):
    """``BatchNorm1d`` followed by ``Linear``, foldable into one ``Linear``.

    Submodules are registered as ``bn`` and ``l``.

    Args:
        a (int): Input features.
        b (int): Output features.
        bias (bool): Whether the linear layer has a bias.
        std (float): Std of the truncated-normal init of the linear weight.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return one ``Linear`` equivalent to BN + Linear in eval mode.

        The fused layer always has a bias, because the BN shift folds into it.
        It is created on the device and with the dtype of the existing linear
        weight, so fusing does not silently move the layer to CPU float32.

        Returns:
            torch.nn.Linear: The fused linear layer.
        """
        bn, lin = self._modules.values()
        # BN in eval mode is x * scale + shift (per input feature).
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        shift = bn.bias - bn.running_mean * scale
        w = lin.weight * scale[None, :]
        # W (x * scale + shift) + b  ==  (W * scale) x + (W @ shift + b)
        b = lin.weight @ shift
        if lin.bias is not None:
            b = b + lin.bias
        m = torch.nn.Linear(w.size(1), w.size(0),
                            device=lin.weight.device, dtype=lin.weight.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class PatchMerging(torch.nn.Module):
    """Downsampling block between stages.

    Pipeline: 1x1 Conv-BN expanding to ``4 * dim`` channels, ReLU, 3x3
    stride-2 depthwise Conv-BN, ReLU, squeeze-excite, and a 1x1 Conv-BN
    projecting to ``out_dim`` channels.

    Args:
        dim (int): Input channels.
        out_dim (int): Output channels.
        input_resolution (int): Input spatial size. It is forwarded as
            ``resolution`` to the Conv-BN layers, halved for the layer after
            the stride-2 convolution.
    """

    def __init__(self, dim, out_dim, input_resolution):
        super().__init__()
        expanded = int(dim * 4)
        self.conv1 = Conv2d_BN(dim, expanded, 1, 1, 0, resolution=input_resolution)
        self.act = torch.nn.ReLU()
        self.conv2 = Conv2d_BN(expanded, expanded, 3, 2, 1, groups=expanded,
                               resolution=input_resolution)
        self.se = SqueezeExcite(expanded, .25)
        self.conv3 = Conv2d_BN(expanded, out_dim, 1, 1, 0,
                               resolution=input_resolution // 2)

    def forward(self, x):
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        out = self.se(out)
        return self.conv3(out)
class Residual(torch.nn.Module):
    """Residual wrapper computing ``x + m(x)`` with optional stochastic depth.

    Args:
        m (torch.nn.Module): Branch module; its output must be addable to x.
        drop (float): During training, probability of dropping the branch for
            each sample. Kept branches are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            branch = self.m(x)
            # One keep/drop draw per sample, broadcast over all remaining dims.
            # Shaping the mask from x.dim() works for any input rank, not only
            # 4-D maps. Casting to x.dtype keeps half/bf16 outputs from being
            # promoted to float32.
            mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
            mask = torch.rand(mask_shape, device=x.device).ge_(self.drop)
            mask = mask.div(1 - self.drop).to(x.dtype).detach()
            return x + branch * mask
        return x + self.m(x)
class FFN(torch.nn.Module):
    """Pointwise feed-forward: 1x1 Conv-BN (ed -> h), ReLU, 1x1 Conv-BN (h -> ed).

    The projection's BN scale is initialised to 0, so the block initially
    outputs zeros.

    Args:
        ed (int): Input/output channels.
        h (int): Hidden channels.
        resolution (int): Forwarded to the Conv-BN layers.
    """

    def __init__(self, ed, h, resolution):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h, resolution=resolution)
        self.act = torch.nn.ReLU()
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0, resolution=resolution)

    def forward(self, x):
        hidden = self.act(self.pw1(x))
        return self.pw2(hidden)
class CascadedGroupAttention(torch.nn.Module):
    r""" Cascaded Group Attention.

    The input channels are split into ``num_heads`` groups. Head ``i`` attends
    over group ``i`` plus the output of head ``i - 1`` (the cascade). The
    concatenated head outputs are projected back to ``dim`` channels.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier applied to ``key_dim`` to get the value
            dimension ``d``.
        resolution (int): Input resolution, i.e. the window size.
        kernels (List[int]): Kernel size of the depthwise conv on the query;
            needs at least ``num_heads`` entries.

    Note:
        The cascade adds a ``d``-channel head output to a
        ``dim // num_heads``-channel input group, so the two must be equal.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = int(attn_ratio * key_dim)
        self.attn_ratio = attn_ratio
        qkvs = []
        dws = []
        for i in range(num_heads):
            # Per-head 1x1 projection producing q, k (key_dim each) and v (d).
            qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution))
            # Depthwise conv applied to the query only.
            dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2,
                                 groups=self.key_dim, resolution=resolution))
        self.qkvs = torch.nn.ModuleList(qkvs)
        self.dws = torch.nn.ModuleList(dws)
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            self.d * num_heads, dim, bn_weight_init=0, resolution=resolution))

        # Learned relative-position bias: one scalar per head for every distinct
        # (|dy|, |dx|) offset between two positions of the window.
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode, caching the expanded bias table for eval.

        Returns ``self`` like ``torch.nn.Module.train``, so ``m.eval()`` and
        ``m.train()`` can be chained.
        """
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
        return self

    def _bias_table(self):
        """Return the (num_heads, N, N) bias table for the current mode.

        In eval mode the table cached by ``train(False)`` is reused only if it
        still matches the parameter's device and dtype. ``self.ab`` is a plain
        attribute, so ``.to()``/``.double()``/``.half()`` do not update it.
        Otherwise the table is recomputed.
        """
        fresh = self.attention_biases
        if not self.training:
            cached = getattr(self, 'ab', None)
            if (cached is not None and cached.device == fresh.device
                    and cached.dtype == fresh.dtype):
                return cached
        return fresh[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,C,H,W)
        B, C, H, W = x.shape
        ab = self._bias_table()
        feats_in = x.chunk(len(self.qkvs), dim=1)
        feats_out = []
        feat = feats_in[0]
        for i, qkv in enumerate(self.qkvs):
            if i > 0:  # add the previous head's output to this head's input
                feat = feat + feats_in[i]
            feat = qkv(feat)
            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)  # B, C/h, H, W
            q = self.dws[i](q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)  # B, C/h, N
            attn = (q.transpose(-2, -1) @ k) * self.scale + ab[i]
            attn = attn.softmax(dim=-1)  # B, N, N
            feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)  # B, d, H, W
            feats_out.append(feat)
        return self.proj(torch.cat(feats_out, 1))
class LocalWindowAttention(torch.nn.Module):
    r""" Local Window Attention.

    If the (square) feature map fits inside one window, attention runs on the
    whole map. Otherwise the map is zero-padded on the bottom/right to a
    multiple of the window size, split into non-overlapping windows,
    attended per window, stitched back and cropped.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier applied to ``key_dim`` for the value dim.
        resolution (int): Input resolution (height == width).
        window_resolution (int): Local window resolution.
        kernels (List[int]): The kernel size of the dw conv on query.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.resolution = resolution
        assert window_resolution > 0, 'window_size must be greater than 0'
        self.window_resolution = window_resolution
        # The attention's bias table only has to span the effective window,
        # which can be no larger than the feature map itself.
        self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
                                           attn_ratio=attn_ratio,
                                           resolution=min(window_resolution, resolution),
                                           kernels=kernels,)

    def forward(self, x):
        side = self.resolution
        B, C, got_h, got_w = x.shape
        # Only check this for classifcation models
        assert side == got_h and side == got_w, \
            'input feature has wrong size, expect {}, got {}'.format((side, side), (got_h, got_w))
        win = self.window_resolution
        if side <= win and side <= win:
            return self.attn(x)
        return self._windowed_attention(x, B, C, side, side)

    def _windowed_attention(self, x, B, C, H, W):
        """Pad, partition into windows, attend per window, then reassemble."""
        win = self.window_resolution
        pad_b = (win - H % win) % win
        pad_r = (win - W % win) % win
        needs_pad = pad_b > 0 or pad_r > 0

        tokens = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        if needs_pad:
            tokens = torch.nn.functional.pad(tokens, (0, 0, 0, pad_r, 0, pad_b))
        full_h, full_w = H + pad_b, W + pad_r
        rows, cols = full_h // win, full_w // win

        # BHWC -> B(nH h)(nW w)C -> B nH nW h w C -> (B nH nW) h w C -> (B nH nW) C h w
        windows = tokens.view(B, rows, win, cols, win, C).transpose(2, 3).reshape(
            B * rows * cols, win, win, C
        ).permute(0, 3, 1, 2)
        windows = self.attn(windows)

        # (B nH nW) C h w -> (B nH nW) h w C -> B nH nW h w C -> B(nH h)(nW w)C
        tokens = windows.permute(0, 2, 3, 1).view(
            B, rows, cols, win, win, C
        ).transpose(2, 3).reshape(B, full_h, full_w, C)
        if needs_pad:
            tokens = tokens[:, :H, :W].contiguous()
        return tokens.permute(0, 3, 1, 2)
class EfficientViTBlock(torch.nn.Module):
    """ A basic EfficientViT building block.

    Applies, in order and each with a residual connection: a 3x3 depthwise
    Conv-BN, an FFN, the token mixer, another 3x3 depthwise Conv-BN, and a
    second FFN.

    Args:
        type (str): Type of token mixer. Only 's' (local window
            self-attention) is supported.
        ed (int): Number of input channels.
        kd (int): Dimension for query and key in the token mixer.
        nh (int): Number of attention heads.
        ar (int): Multiplier applied to ``kd`` to get the value dimension.
        resolution (int): Input resolution.
        window_resolution (int): Local window resolution.
        kernels (List[int]): The kernel size of the dw conv on query.

    Raises:
        ValueError: If ``type`` is not 's'.
    """

    def __init__(self, type,
                 ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        # Fail here instead of with an AttributeError on the missing ``mixer``
        # at the first forward pass.
        if type != 's':
            raise ValueError(
                "unsupported token mixer type {!r}; only 's' is implemented".format(type))
        self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
        self.ffn0 = Residual(FFN(ed, int(ed * 2), resolution))
        self.mixer = Residual(LocalWindowAttention(
            ed, kd, nh, attn_ratio=ar, resolution=resolution,
            window_resolution=window_resolution, kernels=kernels))
        self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
        self.ffn1 = Residual(FFN(ed, int(ed * 2), resolution))

    def forward(self, x):
        return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
class EfficientViT(torch.nn.Module):
    """EfficientViT classifier: conv stem, three block stages, pooled linear head.

    Args:
        img_size (int): Input image side length.
        patch_size (int): Used to derive the stage-1 resolution as
            ``img_size // patch_size``.
        in_chans (int): Input image channels.
        num_classes (int): Output classes; ``<= 0`` gives an identity head.
        stages (Sequence[str]): Token mixer type per stage.
        embed_dim (Sequence[int]): Channels per stage.
        key_dim (Sequence[int]): Query/key dimension per stage.
        depth (Sequence[int]): Number of EfficientViT blocks per stage.
        num_heads (Sequence[int]): Attention heads per stage.
        window_size (Sequence[int]): Local window resolution per stage.
        kernels (Sequence[int]): Query dw-conv kernel size per head.
        down_ops (Sequence[Sequence]): Per stage, ``('subsample', stride)``
            appends a downsampling block to the start of the next stage.
            Any other first entry means no downsampling.
        distillation (bool): Add a second (distillation) head.

    Raises:
        ValueError: If ``down_ops`` requests a subsample after the last stage.
    """

    def __init__(self, img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 stages=('s', 's', 's'),
                 embed_dim=(64, 128, 192),
                 key_dim=(16, 16, 16),
                 depth=(1, 2, 3),
                 num_heads=(4, 4, 4),
                 window_size=(7, 7, 7),
                 kernels=(5, 5, 5, 5),
                 down_ops=(('subsample', 2), ('subsample', 2), ('',)),
                 distillation=False,):
        super().__init__()
        resolution = img_size
        # Patch embedding: four 3x3 stride-2 Conv-BN layers.
        self.patch_embed = torch.nn.Sequential(
            Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1, resolution=resolution), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, resolution=resolution // 2), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, resolution=resolution // 4), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1, resolution=resolution // 8))
        # NOTE(review): the stem always downsamples by 16, so this equals the
        # real feature size only when patch_size == 16 and img_size % 16 == 0.
        resolution = img_size // patch_size
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]

        # Build EfficientViT blocks into one list per stage.
        num_stages = 3
        stage_blocks = [[] for _ in range(num_stages)]
        for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
                zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
            for _ in range(dpth):
                stage_blocks[i].append(EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
            if do[0] == 'subsample':
                if i + 1 >= num_stages:
                    raise ValueError(
                        'down_ops[{}] requests a subsample after the last stage'.format(i))
                # Downsample block ('subsample' stride) prepended to the next stage.
                blk = stage_blocks[i + 1]
                resolution_ = (resolution - 1) // do[1] + 1
                blk.append(torch.nn.Sequential(
                    Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, resolution=resolution)),
                    Residual(FFN(ed, int(ed * 2), resolution)),))
                blk.append(PatchMerging(*embed_dim[i:i + 2], resolution))
                resolution = resolution_
                nxt = embed_dim[i + 1]
                blk.append(torch.nn.Sequential(
                    Residual(Conv2d_BN(nxt, nxt, 3, 1, 1, groups=nxt, resolution=resolution)),
                    Residual(FFN(nxt, int(nxt * 2), resolution)),))
        self.blocks1 = torch.nn.Sequential(*stage_blocks[0])
        self.blocks2 = torch.nn.Sequential(*stage_blocks[1])
        self.blocks3 = torch.nn.Sequential(*stage_blocks[2])

        # Classification head
        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the state-dict names of all ``attention_biases`` parameters."""
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    def forward(self, x):
        """Classify ``x``.

        With distillation, returns ``(head, head_dist)`` logits in training
        mode and their average in eval mode. Otherwise returns the head output.
        """
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x
# Named model configurations. They carry a ``_cfg`` suffix so the registered
# factory functions below can use the bare names without shadowing (and thus
# hiding) the configuration dicts.
EfficientViT_d_cfg = {
    'img_size': 224,
    'patch_size': 16,
    'embed_dim': [96, 144, 400],
    'depth': [1, 3, 4],
    'num_heads': [3, 3, 4],
    'window_size': [7, 7, 7],
    'kernels': [7, 5, 3, 3],
}
EfficientViT_w_cfg = {
    'img_size': 224,
    'patch_size': 16,
    'embed_dim': [192, 288, 96],
    'depth': [1, 1, 1],
    'num_heads': [3, 3, 4],
    'window_size': [7, 7, 7],
    'kernels': [7, 5, 3, 3],
}


def _create_efficientvit(num_classes, distillation, fuse, model_cfg):
    """Build an ``EfficientViT`` from ``model_cfg``, folding BN layers if ``fuse``."""
    model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg)
    if fuse:
        replace_batchnorm(model)
    return model


@register_model
def EfficientViT_d(num_classes=5, pretrained=False, distillation=False, fuse=False,
                   pretrained_cfg=None, model_cfg=EfficientViT_d_cfg):
    """Create the ``EfficientViT_d`` variant.

    ``pretrained`` and ``pretrained_cfg`` are accepted but ignored; no weights
    are loaded.
    """
    return _create_efficientvit(num_classes, distillation, fuse, model_cfg)


@register_model
def EfficientViT_w(num_classes=5, pretrained=False, distillation=False, fuse=False,
                   pretrained_cfg=None, model_cfg=EfficientViT_w_cfg):
    """Create the ``EfficientViT_w`` variant.

    ``pretrained`` and ``pretrained_cfg`` are accepted but ignored; no weights
    are loaded.
    """
    return _create_efficientvit(num_classes, distillation, fuse, model_cfg)
def replace_batchnorm(net):
    """Recursively fold or strip BatchNorm layers of ``net`` in place.

    A child exposing ``fuse()`` is replaced by what that call returns. A bare
    ``BatchNorm2d`` child becomes ``Identity``. Any other child is descended
    into. Returns ``None``.
    """
    for name, module in net.named_children():
        if hasattr(module, 'fuse'):
            replacement = module.fuse()
        elif isinstance(module, torch.nn.BatchNorm2d):
            replacement = torch.nn.Identity()
        else:
            replace_batchnorm(module)
            continue
        setattr(net, name, replacement)
# End of file.