# Uploaded to Hugging Face by Prasanta4 — commit "Create model.py" (d419456, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
# ---------------------- Attention Modules ----------------------
class PositionAttentionModule(nn.Module):
    """Position (spatial) attention in the style of DANet (Fu et al., CVPR 2019).

    Builds an (HW x HW) affinity map between spatial positions and uses it to
    reweight the value features, which are blended back into the input through
    a learnable scale ``gamma``.

    Args:
        in_channels: number of channels of the 4-D input (b, c, h, w).
    """

    def __init__(self, in_channels):
        super().__init__()
        mid = max(1, in_channels // 8)  # reduced dim for query/key projections
        self.conv_b = nn.Conv2d(in_channels, mid, 1, bias=False)
        self.conv_c = nn.Conv2d(in_channels, mid, 1, bias=False)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        # Fix: learnable residual scale, matching ChannelAttentionModule and the
        # DANet paper. Zero-init makes the module an exact identity at the start
        # of training, which stabilizes fine-tuning of a pretrained backbone.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.size()
        B = self.conv_b(x).view(b, -1, h * w)                   # key:   (b, mid, hw)
        C = self.conv_c(x).view(b, -1, h * w).permute(0, 2, 1)  # query: (b, hw, mid)
        D = self.conv_d(x).view(b, -1, h * w)                   # value: (b, c, hw)
        A = self.softmax(torch.bmm(C, B))                       # (b, hw, hw) position affinity
        out = torch.bmm(D, A.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x
class ChannelAttentionModule(nn.Module):
    """Channel attention: reweights channels via a (C x C) channel-affinity map.

    The attention output is blended with the input through a learnable scale
    ``gamma`` that starts at zero, so the module is an identity at init.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.shape
        flat = x.reshape(batch, channels, height * width)       # (b, c, hw)
        affinity = flat @ flat.transpose(1, 2)                  # (b, c, c)
        attn = affinity.softmax(dim=-1)
        weighted = (attn @ flat).reshape(batch, channels, height, width)
        return self.gamma * weighted + x
class DANet(nn.Module):
    """Dual-attention head: applies position and channel attention to the same
    input, sums the two results, and fuses them with a 1x1 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.pam = PositionAttentionModule(in_channels)
        self.cam = ChannelAttentionModule(in_channels)
        self.fuse = nn.Conv2d(in_channels, in_channels, 1, bias=False)

    def forward(self, x):
        spatial = self.pam(x)
        channel = self.cam(x)
        return self.fuse(spatial + channel)
class MSFE(nn.Module):
    """Multi-Scale Feature Extraction.

    Runs three parallel convolutions (1x1, 3x3, 5x5) at equal output width,
    concatenates them, normalizes and activates, then projects back to
    ``in_channels`` with a 1x1 conv, so the module is shape-preserving.

    Args:
        in_channels: channels of the input feature map.
        branch_out: channels per branch; defaults to ``max(1, in_channels // 3)``.
    """

    def __init__(self, in_channels, branch_out=None):
        super().__init__()
        branch_out = max(1, in_channels // 3) if branch_out is None else branch_out
        self.b1 = nn.Conv2d(in_channels, branch_out, 1, bias=False)
        self.b2 = nn.Conv2d(in_channels, branch_out, 3, padding=1, bias=False)
        self.b3 = nn.Conv2d(in_channels, branch_out, 5, padding=2, bias=False)
        self.bn = nn.BatchNorm2d(branch_out * 3)
        self.proj = nn.Conv2d(branch_out * 3, in_channels, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        branches = (self.b1(x), self.b2(x), self.b3(x))
        fused = self.act(self.bn(torch.cat(branches, dim=1)))
        return self.proj(fused)
class PreNorm(nn.Module):
    """Pre-activation unit: BatchNorm2d followed by ReLU over ``c`` channels."""

    def __init__(self, c):
        super().__init__()
        self.bn = nn.BatchNorm2d(c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        normalized = self.bn(x)
        return self.act(normalized)
class MSFE_then_DANet_Block(nn.Module):
    """Residual block: BN+ReLU pre-activation -> MSFE -> DANet -> optional
    spatial dropout, with the result added back to the input."""

    def __init__(self, c, branch_out=None, drop_p=0.0):
        super().__init__()
        self.pre = PreNorm(c)
        self.msfe = MSFE(c, branch_out=branch_out)
        self.danet = DANet(c)
        # Identity stand-in keeps the module graph uniform when dropout is off.
        self.dp = nn.Dropout2d(drop_p) if drop_p > 0 else nn.Identity()

    def forward(self, x):
        residual = self.dp(self.danet(self.msfe(self.pre(x))))
        return x + residual
class DANet_Block(nn.Module):
    """Residual block: BN+ReLU pre-activation -> DANet -> optional spatial
    dropout, with the result added back to the input."""

    def __init__(self, c, drop_p=0.0):
        super().__init__()
        self.pre = PreNorm(c)
        self.danet = DANet(c)
        self.dp = nn.Dropout2d(drop_p) if drop_p > 0 else nn.Identity()

    def forward(self, x):
        attended = self.danet(self.pre(x))
        return x + self.dp(attended)
# ---------------------- Main Model ----------------------
class EfficientNetB0Hybrid(nn.Module):
    """EfficientNet-B0 backbone with attention blocks inserted after selected feature stages.

    During construction, forward hooks record the output shape of each requested
    stage of ``backbone.features`` via a dummy (1, 3, 224, 224) pass; those
    captured channel counts size the inserted blocks, so no channel numbers are
    hard-coded. At inference, ``forward`` runs the stages in order and applies
    the matching block right after each flagged stage.

    Args:
        num_classes: output size of the final linear classifier.
        msfe_then_danet_indices: stage indices followed by an MSFE_then_DANet_Block.
        danet_only_indices: stage indices followed by a DANet_Block.
        branch_out_ratio: per-branch MSFE width as a fraction of the stage's channels.
        drop_p: Dropout2d probability inside the inserted blocks (0 disables).
        use_pretrained: load the default torchvision ImageNet weights when True.
    """
    def __init__(self, num_classes=7, msfe_then_danet_indices=(6,), danet_only_indices=(4,),
                 branch_out_ratio=1/3, drop_p=0.0, use_pretrained=True):
        super().__init__()
        self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT if use_pretrained else None)
        self.features = self.backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        # New head sized from the backbone's classifier input width.
        # NOTE(review): self.backbone.classifier itself is never called in
        # forward, but its parameters remain in the state dict.
        self.classifier = nn.Linear(self.backbone.classifier[1].in_features, num_classes)
        # Capture feature shapes
        self._feat_shapes = {}  # stage index -> torch.Size of that stage's output
        self._register_shape_hooks(msfe_then_danet_indices, danet_only_indices)
        with torch.no_grad():
            # Dummy pass only to trigger the shape hooks; result is discarded.
            # NOTE(review): assumes 224x224 input is valid for the backbone — TODO confirm.
            dummy = torch.zeros(1, 3, 224, 224)
            _ = self._run_features_with_capture(dummy)
        for h in self._hooks: h.remove()  # hooks are construction-time only
        self.msfe_then_danet_blocks = nn.ModuleDict()
        self.danet_only_blocks = nn.ModuleDict()
        for idx in msfe_then_danet_indices:
            C = self._feat_shapes[idx][1]  # channel count captured by the hook
            br = max(1, int(C * branch_out_ratio))
            # ModuleDict requires string keys, hence str(idx).
            self.msfe_then_danet_blocks[str(idx)] = MSFE_then_DANet_Block(C, branch_out=br, drop_p=drop_p)
        for idx in danet_only_indices:
            C = self._feat_shapes[idx][1]
            self.danet_only_blocks[str(idx)] = DANet_Block(C, drop_p=drop_p)
        # Sets for O(1) membership tests in forward.
        self.msfe_then_danet_indices = set(msfe_then_danet_indices)
        self.danet_only_indices = set(danet_only_indices)
    def _register_shape_hooks(self, msfe_then_danet_indices, danet_only_indices):
        """Attach forward hooks that record output shapes of the wanted stages."""
        self._hooks = []
        wanted = set(msfe_then_danet_indices) | set(danet_only_indices)
        def make_hook(i):
            # Closure factory: binds the stage index i for each hook.
            def hook(_m, _inp, out):
                if isinstance(out, (list, tuple)):
                    out = out[0]
                self._feat_shapes[i] = out.shape
            return hook
        for i, m in enumerate(self.features):
            if i in wanted:
                self._hooks.append(m.register_forward_hook(make_hook(i)))
    def _run_features_with_capture(self, x):
        """Run the backbone stages sequentially (hooks do the capturing)."""
        for m in self.features:
            x = m(x)
        return x
    def forward(self, x):
        # Run each backbone stage, inserting the attention block(s) registered
        # for that stage index immediately after it.
        for i, m in enumerate(self.features):
            x = m(x)
            if i in self.msfe_then_danet_indices:
                x = self.msfe_then_danet_blocks[str(i)](x)
            if i in self.danet_only_indices:
                x = self.danet_only_blocks[str(i)](x)
        x = self.pool(x).flatten(1)  # global average pool -> (batch, channels)
        return self.classifier(x)