|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights |
|
|
|
|
|
|
|
|
|
|
|
class PositionAttentionModule(nn.Module):
    """Position (spatial) attention in the DANet style.

    Each spatial location is replaced by a weighted sum over all locations,
    with affinities computed from reduced-dimension query/key projections.
    The attention branch is blended in through a learnable scale ``gamma``
    that starts at 0, so the module is an identity at initialization —
    consistent with ``ChannelAttentionModule`` and the DANet paper. (The
    previous version added the full attention output unscaled, which
    perturbs a pretrained backbone from the very first step.)
    """

    def __init__(self, in_channels):
        super().__init__()
        # Reduced dimensionality for the query/key projections (paper uses C/8).
        mid = max(1, in_channels // 8)
        self.conv_b = nn.Conv2d(in_channels, mid, 1, bias=False)
        self.conv_c = nn.Conv2d(in_channels, mid, 1, bias=False)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        # Zero-initialized residual gate (DANet convention): identity at init.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.size()
        B = self.conv_b(x).view(b, -1, h * w)                   # keys    (b, mid, N)
        C = self.conv_c(x).view(b, -1, h * w).permute(0, 2, 1)  # queries (b, N, mid)
        D = self.conv_d(x).view(b, -1, h * w)                   # values  (b, c, N)
        A = self.softmax(torch.bmm(C, B))                       # affinity (b, N, N)
        out = torch.bmm(D, A.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x
|
|
|
|
|
|
|
|
class ChannelAttentionModule(nn.Module):
    """Channel attention in the DANet style.

    Re-weights each channel by a softmax over the channel-to-channel
    affinity (Gram) matrix. The attention branch is gated by a learnable
    scale ``gamma`` that starts at 0, so the module is an identity at
    initialization.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Zero-initialized residual gate: identity mapping at init.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.shape
        flat = x.flatten(2)                          # (B, C, H*W)
        affinity = flat @ flat.transpose(1, 2)       # (B, C, C)
        weights = F.softmax(affinity, dim=-1)
        attended = (weights @ flat).view(batch, channels, height, width)
        return self.gamma * attended + x
|
|
|
|
|
|
|
|
class DANet(nn.Module):
    """Dual-attention head: position and channel attention applied in
    parallel, summed, then merged by a 1x1 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.pam = PositionAttentionModule(in_channels)
        self.cam = ChannelAttentionModule(in_channels)
        self.fuse = nn.Conv2d(in_channels, in_channels, 1, bias=False)

    def forward(self, x):
        spatial = self.pam(x)
        channel = self.cam(x)
        return self.fuse(spatial + channel)
|
|
|
|
|
|
|
|
class MSFE(nn.Module):
    """Multi-scale feature extraction.

    Runs three parallel convolutions (1x1, 3x3, 5x5) over the input,
    concatenates the branches, normalizes and activates, then projects
    back to ``in_channels`` with a 1x1 convolution.
    """

    def __init__(self, in_channels, branch_out=None):
        super().__init__()
        # Default each branch to roughly a third of the input width.
        if branch_out is None:
            branch_out = max(1, in_channels // 3)
        self.b1 = nn.Conv2d(in_channels, branch_out, 1, bias=False)
        self.b2 = nn.Conv2d(in_channels, branch_out, 3, padding=1, bias=False)
        self.b3 = nn.Conv2d(in_channels, branch_out, 5, padding=2, bias=False)
        self.bn = nn.BatchNorm2d(branch_out * 3)
        self.proj = nn.Conv2d(branch_out * 3, in_channels, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        branches = (self.b1(x), self.b2(x), self.b3(x))
        fused = self.act(self.bn(torch.cat(branches, dim=1)))
        return self.proj(fused)
|
|
|
|
|
|
|
|
class PreNorm(nn.Module):
    """BatchNorm + ReLU, applied ahead of a residual branch."""

    def __init__(self, c):
        super().__init__()
        self.bn = nn.BatchNorm2d(c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        normalized = self.bn(x)
        return self.act(normalized)
|
|
|
|
|
|
|
|
class MSFE_then_DANet_Block(nn.Module):
    """Residual block: pre-norm -> MSFE -> DANet -> optional dropout,
    with the result added back to the input."""

    def __init__(self, c, branch_out=None, drop_p=0.0):
        super().__init__()
        self.pre = PreNorm(c)
        self.msfe = MSFE(c, branch_out=branch_out)
        self.danet = DANet(c)
        # Spatial dropout only when a positive probability is requested.
        self.dp = nn.Identity() if drop_p <= 0 else nn.Dropout2d(drop_p)

    def forward(self, x):
        residual = self.pre(x)
        residual = self.msfe(residual)
        residual = self.danet(residual)
        return x + self.dp(residual)
|
|
|
|
|
|
|
|
class DANet_Block(nn.Module):
    """Residual block: pre-norm -> DANet -> optional dropout, with the
    result added back to the input."""

    def __init__(self, c, drop_p=0.0):
        super().__init__()
        self.pre = PreNorm(c)
        self.danet = DANet(c)
        # Spatial dropout only when a positive probability is requested.
        self.dp = nn.Identity() if drop_p <= 0 else nn.Dropout2d(drop_p)

    def forward(self, x):
        normalized = self.pre(x)
        attended = self.danet(normalized)
        return x + self.dp(attended)
|
|
|
|
|
|
|
|
|
|
|
class EfficientNetB0Hybrid(nn.Module):
    """EfficientNet-B0 backbone with attention blocks inserted after
    selected feature stages.

    After backbone stage ``i`` the activation passes through a
    ``MSFE_then_DANet_Block`` (if ``i`` is in ``msfe_then_danet_indices``)
    and/or a ``DANet_Block`` (if ``i`` is in ``danet_only_indices``).
    Stage channel counts are discovered at construction time with forward
    hooks on a dummy pass, so insertion points are configurable without
    hard-coding stage widths.

    Args:
        num_classes: size of the new classification head.
        msfe_then_danet_indices: feature-stage indices that get MSFE+DANet.
        danet_only_indices: feature-stage indices that get DANet only.
        branch_out_ratio: per-branch width ratio for the MSFE blocks.
        drop_p: Dropout2d probability inside the inserted blocks (0 disables).
        use_pretrained: load ImageNet weights for the backbone.
    """

    def __init__(self, num_classes=7, msfe_then_danet_indices=(6,), danet_only_indices=(4,),
                 branch_out_ratio=1/3, drop_p=0.0, use_pretrained=True):
        super().__init__()
        self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT if use_pretrained else None)
        self.features = self.backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        # New head. The backbone's own classifier stays attached (unused in
        # forward) so pretrained state dicts keep loading without key errors.
        self.classifier = nn.Linear(self.backbone.classifier[1].in_features, num_classes)

        # Probe per-stage output shapes with hooks and a dummy forward.
        self._feat_shapes = {}
        self._register_shape_hooks(msfe_then_danet_indices, danet_only_indices)
        # Run the probe in eval mode: torch.no_grad() does NOT stop BatchNorm
        # from updating its running statistics, so a train-mode pass over the
        # all-zeros dummy would corrupt the pretrained BN stats.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 224, 224)
            _ = self._run_features_with_capture(dummy)
        self.train(was_training)
        for h in self._hooks:
            h.remove()

        # Fail early with a clear message if an index never produced a shape
        # (previously this surfaced later as a bare KeyError).
        wanted = set(msfe_then_danet_indices) | set(danet_only_indices)
        missing = wanted - set(self._feat_shapes)
        if missing:
            raise ValueError(f"feature indices {sorted(missing)} are out of range for this backbone")

        self.msfe_then_danet_blocks = nn.ModuleDict()
        self.danet_only_blocks = nn.ModuleDict()

        for idx in msfe_then_danet_indices:
            C = self._feat_shapes[idx][1]
            br = max(1, int(C * branch_out_ratio))
            self.msfe_then_danet_blocks[str(idx)] = MSFE_then_DANet_Block(C, branch_out=br, drop_p=drop_p)

        for idx in danet_only_indices:
            C = self._feat_shapes[idx][1]
            self.danet_only_blocks[str(idx)] = DANet_Block(C, drop_p=drop_p)

        self.msfe_then_danet_indices = set(msfe_then_danet_indices)
        self.danet_only_indices = set(danet_only_indices)

    def _register_shape_hooks(self, msfe_then_danet_indices, danet_only_indices):
        """Attach forward hooks that record the output shape of each wanted stage."""
        self._hooks = []
        wanted = set(msfe_then_danet_indices) | set(danet_only_indices)

        def make_hook(i):
            # Bind the stage index per hook (avoids late-binding closure bugs).
            def hook(_m, _inp, out):
                if isinstance(out, (list, tuple)):
                    out = out[0]
                self._feat_shapes[i] = out.shape
            return hook

        for i, m in enumerate(self.features):
            if i in wanted:
                self._hooks.append(m.register_forward_hook(make_hook(i)))

    def _run_features_with_capture(self, x):
        """Run the backbone feature stack sequentially; hooks record shapes."""
        for m in self.features:
            x = m(x)
        return x

    def forward(self, x):
        """Backbone stages with attention blocks spliced in, then pool + classify."""
        for i, m in enumerate(self.features):
            x = m(x)
            if i in self.msfe_then_danet_indices:
                x = self.msfe_then_danet_blocks[str(i)](x)
            if i in self.danet_only_indices:
                x = self.danet_only_blocks[str(i)](x)
        x = self.pool(x).flatten(1)
        return self.classifier(x)
|
|
|