|
|
import cv2 |
|
|
import torch |
|
|
from torch import nn |
|
|
from einops.layers.torch import Rearrange |
|
|
from .DCT import Learnable_DCT2D |
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """ConvNeXtV2 block with an additional spatial-attention gate.

    The residual branch is: depthwise 7x7 conv -> LayerNorm -> Linear (4x
    expansion) -> GELU -> GRN -> Linear (back to ``dim``). Its output is
    multiplied by the map produced by ``Spatial_Attention`` (bilinearly resized
    to the branch's spatial size) and added to the block input, with optional
    stochastic depth on the branch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attention = Spatial_Attention()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        # (N, C, H, W) -> (N, H, W, C) so LayerNorm/Linear/GRN act on channels.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        # Back to (N, C, H, W).
        x = x.permute(0, 3, 1, 2)

        attention = self.attention(x)
        # Same operation as nn.UpsamplingBilinear2d (bilinear, align_corners=True),
        # without constructing a new module on every forward pass.
        attention = nn.functional.interpolate(
            attention, size=x.shape[2:], mode='bilinear', align_corners=True)
        x = x * attention

        return shortcut + self.drop_path(x)
|
|
|
|
|
class Spatial_Attention(nn.Module):
    """Spatial attention map built from channel-pooled statistics.

    For an (N, C, H, W) input, the channel-wise mean and max are stacked into a
    2-channel map. The map is adaptively average-pooled to 7x7, fused to one
    channel by a 7x7 conv, and refined by a single-head ``TransformerBlock``
    built for a 7x7 grid. The caller resizes the resulting low-resolution map.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7])

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stats = torch.cat((channel_mean, channel_max), dim=1)
        fused = self.conv(self.avgpool(stats))
        return self.attention(fused)
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block (attention then MLP) on (N, C, H, W) feature maps.

    Each sub-layer flattens the map to (N, ih*iw, C) tokens, applies
    ``nn.LayerNorm`` followed by the sub-layer, and reshapes back to an
    (ih, iw) grid taken from ``img_size``.

    Without downsampling: ``x = x + attn(x)``. With ``downsample=True``, both
    paths are max-pooled (kernel 3, stride 2, padding 1) and the shortcut is
    projected to ``oup`` channels by a 1x1 conv. In that case ``img_size`` must
    be the pooled grid size. Finally ``x = x + ff(x)``.

    Args:
        inp (int): Input channels.
        oup (int): Output channels.
        heads (int): Attention heads. Default: 8.
        dim_head (int): Channels per head. Default: 32.
        img_size (sequence of 2 ints): (ih, iw) token grid. Required.
        downsample (bool): Enable the pooled/projected variant. Default: False.
        dropout (float): Dropout inside attention and MLP. Default: 0.
    """

    def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
        super().__init__()
        mlp_width = int(inp * 4)

        self.downsample = downsample
        self.ih, self.iw = img_size

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        attention = Attention(inp, oup, heads, dim_head, dropout)
        feed_forward = FeedForward(oup, mlp_width, dropout)

        flatten = 'b c ih iw -> b (ih iw) c'
        unflatten = 'b (ih iw) c -> b c ih iw'
        self.attn = nn.Sequential(
            Rearrange(flatten),
            PreNorm(inp, attention, nn.LayerNorm),
            Rearrange(unflatten, ih=self.ih, iw=self.iw),
        )
        self.ff = nn.Sequential(
            Rearrange(flatten),
            PreNorm(oup, feed_forward, nn.LayerNorm),
            Rearrange(unflatten, ih=self.ih, iw=self.iw),
        )

    def forward(self, x):
        if self.downsample:
            shortcut = self.proj(self.pool1(x))
            attn_input = self.pool2(x)
        else:
            shortcut = x
            attn_input = x
        x = shortcut + self.attn(attn_input)
        return x + self.ff(x)
|
|
|
|
|
|
|
|
class CSATv2(nn.Module):
    """CSATv2 image classifier.

    Structure:
    - A ``Learnable_DCT2D(8)`` stem.
    - Four stages of ConvNeXtV2-style ``Block``s. Stages 3 and 4 each end with
      two ``TransformerBlock``s, and stages 1-3 end with LayerNorm plus a 2x2
      stride-2 conv.
    - Global average pooling, ``nn.LayerNorm`` and a linear head.

    Args:
        img_size (int): Input image side length. Required despite the ``None``
            default. Stage-3 transformer blocks are built for an
            ``int(img_size / 32)`` square grid, stage-4 for ``int(img_size / 64)``.
        num_classes (int): Number of output logits. Default: 1000.
        drop_path_rate (float): Largest stochastic-depth rate. Per-block rates
            rise linearly from 0 to this value across all 14 ``Block``s. Default: 0.
        head_init_scale (float): Factor applied to the head's initial weight and
            bias. Default: 1.

    Raises:
        TypeError: if ``img_size`` is not given.
    """

    def __init__(self, img_size=None, num_classes=1000, drop_path_rate=0, head_init_scale=1):
        super().__init__()
        if img_size is None:
            raise TypeError("CSATv2 requires img_size (the input image side length)")

        dims = [32, 72, 168, 386]
        channel_order = "channels_first"
        depths = [2, 2, 6, 4]

        # One stochastic-depth rate per Block, increasing linearly over the whole
        # network. Each stage takes the next depths[i] rates so the schedule keeps
        # rising through the stages instead of restarting at 0 in every stage.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        stage_rates = []
        start = 0
        for depth in depths:
            stage_rates.append(dp_rates[start:start + depth])
            start += depth

        # Token grids the transformer blocks in stages 3 and 4 are built for.
        grid3 = [int(img_size / 32), int(img_size / 32)]
        grid4 = [int(img_size / 64), int(img_size / 64)]

        self.stages1 = nn.Sequential(
            *[Block(dim=dims[0], drop_path=rate) for rate in stage_rates[0]],
            LayerNorm(dims[0], eps=1e-6, data_format=channel_order),
            nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2),  # halves H and W
        )

        self.stages2 = nn.Sequential(
            *[Block(dim=dims[1], drop_path=rate) for rate in stage_rates[1]],
            LayerNorm(dims[1], eps=1e-6, data_format=channel_order),
            nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2),
        )

        self.stages3 = nn.Sequential(
            *[Block(dim=dims[2], drop_path=rate) for rate in stage_rates[2]],
            TransformerBlock(inp=dims[2], oup=dims[2], img_size=grid3),
            TransformerBlock(inp=dims[2], oup=dims[2], img_size=grid3),
            LayerNorm(dims[2], eps=1e-6, data_format=channel_order),
            nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2),
        )

        self.stages4 = nn.Sequential(
            *[Block(dim=dims[3], drop_path=rate) for rate in stage_rates[3]],
            TransformerBlock(inp=dims[3], oup=dims[3], img_size=grid4),
            TransformerBlock(inp=dims[3], oup=dims[3], img_size=grid4),
        )

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)
        # Registered after self.apply(...), so _init_weights does not touch it.
        self.dct = Learnable_DCT2D(8)

    def load_checkpoint(self, checkpoint):
        """Load matching backbone weights from ``checkpoint``.

        The classifier head is kept as currently initialised.

        The loaded object may hold its weights under ``'state_dict'`` or
        ``'model'``, or may itself be the state dict. Before matching:
        - the prefixes ``'module.backbone.'`` and ``'resnet.'`` are stripped
          from keys;
        - keys absent from this model are ignored;
        - ``head.weight`` and ``head.bias`` are always skipped.

        Args:
            checkpoint: Anything ``torch.load`` accepts (e.g. a file path).
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(checkpoint, map_location='cpu')
        if 'state_dict' in state:
            state_dict = state['state_dict']
        elif 'model' in state:
            state_dict = state['model']
        else:
            state_dict = state

        state_dict = {
            key.replace('module.backbone.', '').replace('resnet.', ''): value
            for key, value in state_dict.items()
        }

        model_keys = self.state_dict().keys()
        skipped = ('head.weight', 'head.bias')
        weights = {
            key: value for key, value in state_dict.items()
            if key in model_keys and key not in skipped
        }
        self.load_state_dict(weights, strict=False)

    def preprocess(self, x):
        """Convert ``x`` from BGR to YCrCb with ``cv2.cvtColor``.

        Presumably ``x`` is an HxWx3 image array as OpenCV expects; confirm with
        callers. This method is not called by ``forward``.
        """
        return cv2.cvtColor(x, cv2.COLOR_BGR2YCR_CB)

    def _init_weights(self, m):
        """Initialise Conv2d/Linear layers.

        Weights get ``trunc_normal_(std=0.02)``; biases, when present, are zeroed.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        x = self.dct(x)
        x = self.stages1(x)
        x = self.stages2(x)
        x = self.stages3(x)
        x = self.stages4(x)
        # Average over the last two (spatial) dimensions before the final norm.
        x = self.norm(x.mean([-2, -1]))
        return self.head(x)
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
import math |
|
|
import warnings |
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalization over the channel dimension for two memory layouts.

    ``data_format="channels_last"`` (default) expects inputs shaped
    (batch_size, height, width, channels) and delegates to ``F.layer_norm``.
    ``data_format="channels_first"`` expects (batch_size, channels, height,
    width) and normalizes over dim 1 manually.

    Args:
        normalized_shape (int): Number of channels.
        eps (float): Added to the variance for numerical stability. Default: 1e-6.
        data_format (str): ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: for any other ``data_format``.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(variance + self.eps)
            # Broadcast per-channel affine parameters over H and W.
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return None
|
|
|
|
|
|
|
|
class GRN(nn.Module):
    """Global Response Normalization layer (as used in ConvNeXt V2).

    Operates on channels-last tensors: ``gamma``/``beta`` are shaped
    (1, 1, 1, dim), and the L2 norm is taken over dims (1, 2). Both parameters
    start at zero, so a freshly built layer returns its input unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-sample, per-channel L2 norm over the two spatial axes.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Each channel's norm relative to the mean over channels.
        relative = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        calibrated = self.gamma * (x * relative) + self.beta
        return calibrated + x
|
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: Input tensor. Dimension 0 is treated as the sample dimension; one
            keep/drop decision is made per sample.
        drop_prob: Probability of zeroing a sample's path. ``None`` is treated as 0.
        training: Paths are only dropped when True.

    Returns:
        ``x`` unchanged when nothing can be dropped. Otherwise a new tensor
        where each sample is either zeroed or scaled by ``1 / (1 - drop_prob)``.
        For ``drop_prob >= 1`` that is all zeros.
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Every path is dropped. Dividing by keep_prob would yield inf * 0 = NaN.
        return torch.zeros_like(x)
    # One mask value per sample, broadcast over the remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 1 with probability keep_prob, otherwise 0
    return x.div(keep_prob) * random_tensor
|
|
|
|
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``; paths are dropped only while the module
    is in training mode (``self.training``).

    Args:
        drop_prob (float): Probability of dropping a sample's path. Default: 0.
            The previous ``None`` default made ``DropPath()`` raise a TypeError
            in training mode.
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        # Show the rate in print(model) output.
        return f"drop_prob={self.drop_prob}"
|
|
|
|
|
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension.

    Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim)
    -> Dropout. Leading dimensions are passed through unchanged.

    Args:
        dim (int): Input and output feature size.
        hidden_dim (int): Hidden feature size.
        dropout (float): Dropout probability after each Linear. Default: 0.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
|
|
|
|
|
class PreNorm(nn.Module):
    """Normalize the input, then apply ``fn`` to it.

    Args:
        dim: Size passed to the normalization constructor.
        fn: Module or callable applied to the normalized input. Keyword
            arguments given to ``forward`` are forwarded to it.
        norm: Normalization class, instantiated as ``norm(dim)``.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over tokens, preceded by a ``PosCNN`` positional encoding.

    Args:
        inp (int): Channel size of the input tokens.
        oup (int): Channel size of the output tokens.
        heads (int): Number of attention heads. Default: 8.
        dim_head (int): Channels per head. Default: 32.
        dropout (float): Dropout after the output projection. Default: 0.
    """

    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection may be replaced by Identity only when the
        # attention output (inner_dim == dim_head channels for one head) already
        # has oup channels. Requiring inp == oup as well keeps the original
        # behaviour for square configurations and projects whenever inp != oup.
        project_out = not (heads == 1 and dim_head == inp == oup)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
        self.pos_embed = PosCNN(in_chans=inp)

    def forward(self, x):
        """Attend over ``x`` of shape (b, n, inp) and return (b, n, oup).

        ``PosCNN`` requires n to be a perfect square.
        """
        x = self.pos_embed(x)
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
|
|
|
|
|
|
|
class PosCNN(nn.Module):
    """Convolutional positional encoding for a square grid of tokens.

    Tokens of shape (B, N, C) are laid out as a (B, C, S, S) map with
    S = sqrt(N). A depthwise 3x3 convolution of the map is added to the map
    itself, and the result is flattened back to (B, N, C).

    Args:
        in_chans (int): Token channel size C.
    """

    def __init__(self, in_chans):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)

    def forward(self, x):
        """Add the positional encoding to ``x`` of shape (B, N, C).

        Raises:
            ValueError: if N is not a perfect square.
        """
        B, N, C = x.shape
        # math.isqrt is exact; int(N ** 0.5) relies on float rounding.
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"PosCNN expects a square number of tokens, got N={N}")
        # reshape (rather than view) also accepts non-contiguous inputs.
        cnn_feat = x.transpose(1, 2).reshape(B, C, side, side)
        x = self.proj(cnn_feat) + cnn_feat
        return x.flatten(2).transpose(1, 2)
|
|
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. They are generated by inverse-CDF sampling over the
    truncated range (then clamped), not by rejection. Best suited to
    :math:`a \leq \text{mean} \leq b`; a warning is emitted when mean lies more
    than 2 std outside the range.

    Args:
        tensor: an n-dimensional `torch.Tensor` to fill (no autograd tracking).
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
|
|
|
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
|
|
|
|
|
|
|
|
def norm_cdf(x): |
|
|
|
|
|
return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
|
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
|
"The distribution of values may be incorrect.", |
|
|
stacklevel=2) |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
|
|
|
|
|
|
l = norm_cdf((a - mean) / std) |
|
|
u = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
|
|
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.)) |
|
|
tensor.add_(mean) |
|
|
|
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
|
return tensor |