import torch import torch.nn as nn import torch.nn.functional as F from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights # ---------------------- Attention Modules ---------------------- class PositionAttentionModule(nn.Module): def __init__(self, in_channels): super().__init__() mid = max(1, in_channels // 8) self.conv_b = nn.Conv2d(in_channels, mid, 1, bias=False) self.conv_c = nn.Conv2d(in_channels, mid, 1, bias=False) self.conv_d = nn.Conv2d(in_channels, in_channels, 1, bias=False) self.softmax = nn.Softmax(dim=-1) def forward(self, x): b, c, h, w = x.size() B = self.conv_b(x).view(b, -1, h*w) C = self.conv_c(x).view(b, -1, h*w).permute(0, 2, 1) D = self.conv_d(x).view(b, -1, h*w) A = self.softmax(torch.bmm(C, B)) out = torch.bmm(D, A.permute(0, 2, 1)).view(b, c, h, w) return out + x class ChannelAttentionModule(nn.Module): def __init__(self, in_channels): super().__init__() self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): b, c, h, w = x.size() xf = x.view(b, c, -1) A = torch.bmm(xf, xf.permute(0, 2, 1)) A = F.softmax(A, dim=-1) out = torch.bmm(A, xf).view(b, c, h, w) return self.gamma * out + x class DANet(nn.Module): def __init__(self, in_channels): super().__init__() self.pam = PositionAttentionModule(in_channels) self.cam = ChannelAttentionModule(in_channels) self.fuse = nn.Conv2d(in_channels, in_channels, 1, bias=False) def forward(self, x): return self.fuse(self.pam(x) + self.cam(x)) class MSFE(nn.Module): def __init__(self, in_channels, branch_out=None): super().__init__() if branch_out is None: branch_out = max(1, in_channels // 3) self.b1 = nn.Conv2d(in_channels, branch_out, 1, bias=False) self.b2 = nn.Conv2d(in_channels, branch_out, 3, padding=1, bias=False) self.b3 = nn.Conv2d(in_channels, branch_out, 5, padding=2, bias=False) self.bn = nn.BatchNorm2d(branch_out * 3) self.proj = nn.Conv2d(branch_out * 3, in_channels, 1, bias=False) self.act = nn.ReLU(inplace=True) def forward(self, x): y = torch.cat([self.b1(x), self.b2(x), self.b3(x)], dim=1) y = self.proj(self.act(self.bn(y))) return y class PreNorm(nn.Module): def __init__(self, c): super().__init__() self.bn = nn.BatchNorm2d(c) self.act = nn.ReLU(inplace=True) def forward(self, x): return self.act(self.bn(x)) class MSFE_then_DANet_Block(nn.Module): def __init__(self, c, branch_out=None, drop_p=0.0): super().__init__() self.pre = PreNorm(c) self.msfe = MSFE(c, branch_out=branch_out) self.danet = DANet(c) self.dp = nn.Dropout2d(drop_p) if drop_p > 0 else nn.Identity() def forward(self, x): y = self.pre(x) y = self.msfe(y) y = self.danet(y) y = self.dp(y) return x + y class DANet_Block(nn.Module): def __init__(self, c, drop_p=0.0): super().__init__() self.pre = PreNorm(c) self.danet = DANet(c) self.dp = nn.Dropout2d(drop_p) if drop_p > 0 else nn.Identity() def forward(self, x): y = self.dp(self.danet(self.pre(x))) return x + y # ---------------------- Main Model ---------------------- class EfficientNetB0Hybrid(nn.Module): def __init__(self, num_classes=7, msfe_then_danet_indices=(6,), danet_only_indices=(4,), branch_out_ratio=1/3, drop_p=0.0, use_pretrained=True): super().__init__() self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT if use_pretrained else None) self.features = self.backbone.features self.pool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Linear(self.backbone.classifier[1].in_features, num_classes) # Capture feature shapes self._feat_shapes = {} self._register_shape_hooks(msfe_then_danet_indices, danet_only_indices) with torch.no_grad(): dummy = torch.zeros(1, 3, 224, 224) _ = self._run_features_with_capture(dummy) for h in self._hooks: h.remove() self.msfe_then_danet_blocks = nn.ModuleDict() self.danet_only_blocks = nn.ModuleDict() for idx in msfe_then_danet_indices: C = self._feat_shapes[idx][1] br = max(1, int(C * branch_out_ratio)) self.msfe_then_danet_blocks[str(idx)] = MSFE_then_DANet_Block(C, branch_out=br, drop_p=drop_p) for idx in danet_only_indices: C = self._feat_shapes[idx][1] self.danet_only_blocks[str(idx)] = DANet_Block(C, drop_p=drop_p) self.msfe_then_danet_indices = set(msfe_then_danet_indices) self.danet_only_indices = set(danet_only_indices) def _register_shape_hooks(self, msfe_then_danet_indices, danet_only_indices): self._hooks = [] wanted = set(msfe_then_danet_indices) | set(danet_only_indices) def make_hook(i): def hook(_m, _inp, out): if isinstance(out, (list, tuple)): out = out[0] self._feat_shapes[i] = out.shape return hook for i, m in enumerate(self.features): if i in wanted: self._hooks.append(m.register_forward_hook(make_hook(i))) def _run_features_with_capture(self, x): for m in self.features: x = m(x) return x def forward(self, x): for i, m in enumerate(self.features): x = m(x) if i in self.msfe_then_danet_indices: x = self.msfe_then_danet_blocks[str(i)](x) if i in self.danet_only_indices: x = self.danet_only_blocks[str(i)](x) x = self.pool(x).flatten(1) return self.classifier(x)