import torch
import torch.nn as nn
from torchvision import models


class ChannelAttention(nn.Module):
    """Channel-attention gate: rescales each channel of a feature map.

    Global average- and max-pooled descriptors are passed through one shared
    1x1-conv bottleneck MLP, summed, squashed with a sigmoid and multiplied
    back onto the input.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction`` but never less than 16.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        # Floor the bottleneck at 16 so narrow inputs keep some capacity.
        hidden = max(channels // reduction, 16)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # The same MLP (shared weights) processes both pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled per channel by a gate in (0, 1)."""
        attn = self.mlp(self.avg_pool(x)) + self.mlp(self.max_pool(x))
        return x * self.sigmoid(attn)


class SpatialAttention(nn.Module):
    """Spatial-attention gate: rescales each spatial location of a feature map.

    Args:
        kernel_size: Side of the square convolution applied to the stacked
            channel-mean / channel-max maps. Must be a positive odd integer.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # padding = kernel_size // 2 only preserves H x W for odd kernels; an
        # even kernel would produce an (H+1) x (W+1) map that cannot gate x.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=padding, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled per location by a gate in (0, 1).

        Channels are reduced along ``dim=1``, so ``x`` is expected to be
        batched with channels in the second dimension, i.e. (N, C, H, W).
        """
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        attn = torch.cat([avg_map, max_map], dim=1)
        attn = self.sigmoid(self.conv(attn))
        return x * attn


class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Bottleneck ratio forwarded to :class:`ChannelAttention`.
        kernel_size: Convolution size forwarded to :class:`SpatialAttention`.
    """

    def __init__(self, channels: int, reduction: int = 16, kernel_size: int = 7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the channel gate, then the spatial gate."""
        return self.sa(self.ca(x))


class ConvNeXtWithCBAM(nn.Module):
    """ConvNeXt-style backbone with a CBAM block before pooling and a new head.

    The backbone must expose ``features``, ``avgpool`` and an indexable
    ``classifier`` whose element 0 is the normalisation layer and whose
    element 2 has ``in_features`` (the layout used by torchvision's ConvNeXt
    models). The backbone's sub-modules are reused, not copied; its original
    final linear layer is not used.

    Args:
        backbone: Backbone providing the attributes described above.
        n_class: Number of output classes of the new linear head.
        dropout: Dropout probability applied right before the linear head.
    """

    def __init__(self, backbone: nn.Module, n_class: int, dropout: float = 0.65):
        super().__init__()
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        self.norm = backbone.classifier[0]
        # Width of the final feature map, read off the original classifier.
        out_channels = backbone.classifier[2].in_features
        self.attn = CBAMBlock(out_channels, reduction=16, kernel_size=7)
        self.head = nn.Sequential(
            nn.Flatten(1),
            nn.Dropout(p=dropout),
            nn.Linear(out_channels, n_class),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for input ``x``."""
        x = self.features(x)
        # Attention runs on the full spatial map, i.e. before global pooling.
        x = self.attn(x)
        x = self.avgpool(x)
        x = self.norm(x)
        x = self.head(x)
        return x


def set_convnext_stochastic_depth(model: nn.Module, p: float = 0.2) -> None:
    """Set the same stochastic-depth probability on every block of ``model``.

    Every sub-module that has a ``stochastic_depth`` attribute which itself
    exposes ``p`` gets ``stochastic_depth.p`` overwritten in place.

    Args:
        model: Module tree to update.
        p: Drop probability, within ``[0, 1]``.

    Raises:
        ValueError: If ``p`` lies outside ``[0, 1]``.
    """
    # Fail fast here instead of at the first forward pass.
    if not 0.0 <= p <= 1.0:
        raise ValueError(
            f"stochastic depth probability must be in [0, 1], got {p}"
        )
    for module in model.modules():
        sd = getattr(module, "stochastic_depth", None)
        if sd is not None and hasattr(sd, "p"):
            sd.p = p


def build_convnext_cbam(
    n_class: int,
    dropout: float = 0.65,
    sd_prob: float = 0.2,
    weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1,
) -> ConvNeXtWithCBAM:
    """Build a ConvNeXt-Small backbone wrapped in :class:`ConvNeXtWithCBAM`.

    Args:
        n_class: Number of output classes.
        dropout: Dropout probability of the classification head.
        sd_prob: Stochastic-depth probability applied uniformly to all blocks
            via :func:`set_convnext_stochastic_depth`.
        weights: Passed through to ``models.convnext_small``.

    Returns:
        The assembled model.
    """
    backbone = models.convnext_small(weights=weights)
    set_convnext_stochastic_depth(backbone, p=sd_prob)
    model = ConvNeXtWithCBAM(backbone=backbone, n_class=n_class, dropout=dropout)
    return model