| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| |
|
class NormalizedMultiScaleAttention(nn.Module):
    """
    Normalized Multi-Scale Attention (Normalized-MSA) module.

    Enhances multi-scale feature representation by balancing computational
    efficiency with representation strength.

    For each factor ``s`` in ``scales``:

    - the input is average-pooled by ``s`` (identity when ``s == 1``);
    - it is re-weighted by a channel-affinity attention: a ``C x C`` softmax
      over channel dot products, scaled by ``sqrt(H_s * W_s)``;
    - it is multiplied by a per-scale sigmoid spatial map;
    - it is bilinearly resized back to ``(H, W)``.

    The per-scale results are fused with the learnable ``scale_weights``, and
    ``edge_weight * edge_conv(x)`` is added.

    Args:
        in_channels: Number of input channels; the output has the same count.
        scales: Positive pooling factors, one branch per entry. A list copy is
            stored in ``self.scales``, so neither the caller's object nor the
            default value is shared with the module.
        edge_weight: Multiplier of the edge branch. ``0.1`` keeps the original
            behaviour.

    Raises:
        ValueError: If ``scales`` is empty or contains a factor below 1.
    """

    def __init__(self, in_channels, scales=(1, 2, 4), edge_weight=0.1):
        super(NormalizedMultiScaleAttention, self).__init__()
        # Copy so mutating self.scales never touches the caller's sequence.
        self.scales = list(scales)
        if not self.scales or any(s < 1 for s in self.scales):
            raise ValueError("scales must be a non-empty sequence of positive integers")
        self.in_channels = in_channels
        self.edge_weight = edge_weight

        # One conv -> BN -> sigmoid spatial-attention head per scale.
        self.spatial_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(in_channels),
                nn.Sigmoid(),
            )
            for _ in self.scales
        ])

        # Full-resolution conv -> BN -> ReLU branch added to the fused output.
        self.edge_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )

        # Learnable fusion weights, initialised to a uniform average. forward
        # uses them as-is (no re-normalisation), so training may move them
        # away from summing to 1.
        self.scale_weights = nn.Parameter(torch.ones(len(self.scales)) / len(self.scales))

        self._init_weights()

    def _init_weights(self):
        """Use Kaiming-normal init for convolutions and unit weight / zero bias for norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Apply the multi-scale channel/spatial attention.

        Args:
            x: 4-D tensor ``(B, C, H, W)`` with ``C == in_channels``. It may be
                non-contiguous.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        batch_size, channels, height, width = x.size()

        edge_features = self.edge_conv(x)

        weighted_features = []
        for i, scale in enumerate(self.scales):
            x_s = x if scale == 1 else F.avg_pool2d(x, kernel_size=scale, stride=scale)

            spatial_attn = self.spatial_convs[i](x_s)

            # reshape (not view): for scale == 1, x_s is the caller's tensor,
            # which may be non-contiguous and would make .view() raise.
            x_flat = x_s.reshape(batch_size, channels, -1)  # (B, C, N)

            # Channel affinity, scaled by sqrt(N) where N = number of positions.
            norm_factor = math.sqrt(x_flat.size(2))
            channel_attn = torch.bmm(x_flat, x_flat.transpose(1, 2)) / norm_factor
            channel_attn = F.softmax(channel_attn, dim=2)  # (B, C, C)

            # bmm output is contiguous, so view is safe here.
            attended = torch.bmm(channel_attn, x_flat)
            attended = attended.view(batch_size, channels, *x_s.shape[2:])
            attended = attended * spatial_attn

            if scale != 1:
                attended = F.interpolate(attended, size=(height, width), mode='bilinear', align_corners=False)

            weighted_features.append(attended * self.scale_weights[i])

        output = torch.stack(weighted_features, dim=0).sum(dim=0)

        return output + self.edge_weight * edge_features
| |
|
class EntropyOptimizedGating(nn.Module):
    """
    Entropy-Optimized Gating (EOG) module.

    Feature redundancy is adaptively suppressed using a normalized entropy
    function.

    For every (sample, channel) pair, the absolute activations are normalised
    into a distribution over the ``H * W`` positions. Its Shannon entropy is
    divided by ``log(H * W)``, giving a score roughly in ``[0, 1]``
    (``epsilon`` makes it slightly inexact). Channels scoring above ``beta``
    get gate 1 and all others get gate 0. The result is::

        output = x * gate + residual_weight * x

    NOTE: ``beta`` is an ``nn.Parameter``, but it is only used in a hard ``>``
    comparison. That comparison is not differentiable, so ``beta`` receives
    no gradient from this module. ``residual_weight`` does receive gradients.

    Args:
        channels: Number of channels. Stored on the module but not used by
            ``forward``.
        beta: Initial value of the entropy threshold.
        epsilon: Small constant added to the normaliser and inside the log.
    """

    def __init__(self, channels, beta=0.3, epsilon=1e-5):
        super(EntropyOptimizedGating, self).__init__()
        self.channels = channels
        self.beta = nn.Parameter(torch.tensor([beta]))
        self.epsilon = epsilon
        # Learnable scale of the identity path added after gating.
        self.residual_weight = nn.Parameter(torch.tensor([0.2]))

    def forward(self, x):
        """
        Gate each channel of ``x`` by its normalized spatial entropy.

        Args:
            x: 4-D tensor ``(B, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Unpacking keeps the 4-D input requirement of the original code.
        _, _, height, width = x.size()

        # Per-(sample, channel) probability distribution over spatial positions,
        # computed for all channels at once instead of in a Python loop.
        abs_x = torch.abs(x)
        sum_abs = torch.sum(abs_x, dim=(2, 3), keepdim=True) + self.epsilon
        norm_prob = abs_x / sum_abs
        log_prob = torch.log(norm_prob + self.epsilon)
        entropy = -torch.sum(norm_prob * log_prob, dim=(2, 3), keepdim=True)  # (B, C, 1, 1)

        num_positions = height * width
        if num_positions > 1:
            norm_entropy = entropy / math.log(num_positions)
        else:
            # log(1) == 0: a single position has zero entropy; avoid 0/0 -> NaN.
            norm_entropy = torch.zeros_like(entropy)

        # Hard 0/1 gate, in x's dtype so half/bfloat16 products stay in that dtype.
        gates = (norm_entropy > self.beta).to(x.dtype)

        return x * gates + self.residual_weight * x
| |
|
class EOANetModule(nn.Module):
    """
    Entropy-Optimized Attention Network (EOANet) module.

    Combines Normalized Multi-Scale Attention with Entropy-Optimized Gating.
    The input first goes through ``NormalizedMultiScaleAttention``, and that
    output is then gated by ``EntropyOptimizedGating``.

    Args:
        in_channels: Channel count of the input, passed to both sub-modules.
        scales: Pooling factors for the multi-scale attention. The immutable
            tuple default avoids one mutable list being shared (and possibly
            mutated) across instances.
        beta: Initial entropy threshold of the gating stage.
    """

    def __init__(self, in_channels, scales=(1, 2, 4), beta=0.5):
        super(EOANetModule, self).__init__()
        self.msa = NormalizedMultiScaleAttention(in_channels, scales)
        self.eog = EntropyOptimizedGating(in_channels, beta)

    def forward(self, x):
        """Return ``eog(msa(x))``: multi-scale attention followed by entropy gating."""
        x_msa = self.msa(x)
        return self.eog(x_msa)