| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
|
|
class Attention(nn.Module):
    """Attention generator for :class:`ODConv2d`.

    A shared squeeze path (global average pool -> 1x1 conv -> BatchNorm ->
    ReLU) feeds up to four heads. Each head's logits are divided by
    ``self.temperature`` before the activation:

    * channel: sigmoid, shape (N, in_planes, 1, 1)
    * filter:  sigmoid, shape (N, out_planes, 1, 1)
    * spatial: sigmoid, shape (N, 1, 1, 1, kernel_size, kernel_size)
    * kernel:  softmax over candidates, shape (N, kernel_num, 1, 1, 1, 1)

    Some heads are omitted:

    * the filter head when ``in_planes == groups == out_planes``;
    * the spatial head when ``kernel_size == 1``;
    * the kernel head when ``kernel_num == 1``.

    An omitted head yields the scalar ``1.0``.

    Args:
        in_planes: channels of the input tensor.
        out_planes: output channels of the owning convolution.
        kernel_size: side length of the square convolution kernel.
        groups: convolution groups; only used to detect the depthwise case.
        reduction: hidden width is ``max(int(in_planes * reduction), min_channel)``.
        kernel_num: number of candidate kernels.
        min_channel: lower bound on the hidden width.
    """

    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super(Attention, self).__init__()
        hidden_channels = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        # Shared squeeze path.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, hidden_channels, 1, bias=False)
        self.bn = nn.BatchNorm2d(hidden_channels)
        self.relu = nn.ReLU(inplace=True)

        # Channel head: always present.
        self.channel_fc = nn.Conv2d(hidden_channels, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        # Filter head: dropped in the depthwise case.
        depthwise = in_planes == groups and in_planes == out_planes
        if not depthwise:
            self.filter_fc = nn.Conv2d(hidden_channels, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention
        else:
            self.func_filter = self.skip

        # Spatial head: dropped for 1x1 kernels.
        if kernel_size != 1:
            self.spatial_fc = nn.Conv2d(hidden_channels, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention
        else:
            self.func_spatial = self.skip

        # Kernel head: dropped when there is a single candidate kernel.
        if kernel_num != 1:
            self.kernel_fc = nn.Conv2d(hidden_channels, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention
        else:
            self.func_kernel = self.skip

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-init conv weights, zero conv biases, set BatchNorm to identity affine."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def update_temperature(self, temperature):
        """Set the divisor applied to every head's logits."""
        self.temperature = temperature

    @staticmethod
    def skip(_):
        """Stand-in for an omitted head: a neutral multiplicative factor."""
        return 1.0

    def get_channel_attention(self, x):
        """Sigmoid attention over input channels, shape (N, in_planes, 1, 1)."""
        logits = self.channel_fc(x).view(x.size(0), -1, 1, 1)
        return torch.sigmoid(logits / self.temperature)

    def get_filter_attention(self, x):
        """Sigmoid attention over output filters, shape (N, out_planes, 1, 1)."""
        logits = self.filter_fc(x).view(x.size(0), -1, 1, 1)
        return torch.sigmoid(logits / self.temperature)

    def get_spatial_attention(self, x):
        """Sigmoid attention over kernel positions, shape (N, 1, 1, 1, k, k)."""
        logits = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        return torch.sigmoid(logits / self.temperature)

    def get_kernel_attention(self, x):
        """Softmax over candidate kernels, shape (N, kernel_num, 1, 1, 1, 1)."""
        logits = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        return F.softmax(logits / self.temperature, dim=1)

    def forward(self, x):
        """Return ``(channel, filter, spatial, kernel)`` attentions for ``x``."""
        squeezed = self.relu(self.bn(self.fc(self.avgpool(x))))
        return (
            self.func_channel(squeezed),
            self.func_filter(squeezed),
            self.func_spatial(squeezed),
            self.func_kernel(squeezed),
        )
|
|
|
|
class ODConv2d(nn.Module):
    """Omni-dimensional dynamic 2-D convolution.

    Holds ``kernel_num`` candidate kernels of shape
    (out_planes, in_planes // groups, kernel_size, kernel_size). For every
    sample it:

    1. scales the input by the channel attention;
    2. blends the candidates with the spatial and kernel attentions into one
       per-sample kernel;
    3. convolves;
    4. scales the output by the filter attention.

    All attentions come from :class:`Attention`.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        weight_shape = (kernel_num, out_planes, in_planes // groups, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape), requires_grad=True)
        self._initialize_weights()

        # A single 1x1 candidate needs no per-sample kernel aggregation.
        pointwise_single = self.kernel_size == 1 and self.kernel_num == 1
        self._forward_impl = self._forward_impl_pw1x if pointwise_single else self._forward_impl_common

    def _initialize_weights(self):
        """Kaiming-initialise each candidate kernel independently."""
        for candidate in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[candidate], mode='fan_out', nonlinearity='relu')

    def update_temperature(self, temperature):
        """Forward ``temperature`` to the attention module."""
        self.attention.update_temperature(temperature)

    def _forward_impl_common(self, x):
        """General path: aggregate a distinct kernel per sample, then one grouped conv."""
        channel_att, filter_att, spatial_att, kernel_att = self.attention(x)
        batch_size, _, height, width = x.size()
        # Fold the batch into the channel axis so a single grouped conv can
        # apply each sample's own aggregated kernel.
        folded = (x * channel_att).reshape(1, -1, height, width)
        blended = spatial_att * kernel_att * self.weight.unsqueeze(dim=0)
        per_sample_kernels = torch.sum(blended, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        result = F.conv2d(folded, weight=per_sample_kernels, bias=None, stride=self.stride,
                          padding=self.padding, dilation=self.dilation,
                          groups=self.groups * batch_size)
        result = result.view(batch_size, self.out_planes, result.size(-2), result.size(-1))
        return result * filter_att

    def _forward_impl_pw1x(self, x):
        """Fast path for a single 1x1 candidate: an ordinary conv with the shared kernel."""
        channel_att, filter_att, _, _ = self.attention(x)
        result = F.conv2d(x * channel_att, weight=self.weight.squeeze(dim=0), bias=None,
                          stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        return result * filter_att

    def forward(self, x):
        """Apply the dynamic convolution to ``x`` of shape (N, in_planes, H, W)."""
        return self._forward_impl(x)
|
|
|
|
class BasicConv(nn.Module):
    """ODConv2d followed by an optional bias, BatchNorm2d and ReLU.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, dilation, groups:
            forwarded to :class:`ODConv2d`.
        relu: append a ReLU when True.
        bn: append ``BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)``
            when True.
        bias: when True, add a learnable per-output-channel bias
            (zero-initialised) right after the convolution. ODConv2d itself
            convolves with ``bias=None``, so the bias is held here.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = ODConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_planes))
        else:
            # Registered as None so the attribute always exists and the
            # state_dict of bias-free modules is unchanged.
            self.register_parameter('bias', None)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        """Run conv -> (+bias) -> (BN) -> (ReLU) on ``x``."""
        x = self.conv(x)
        if self.bias is not None:
            x = x + self.bias.view(1, -1, 1, 1)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the first: (N, *rest) -> (N, prod(rest)).

    Implemented with ``Tensor.view``, so the result shares storage with the
    input and the input must be view-compatible with that shape.
    """

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)
|
|
class ChannelGate(nn.Module):
    """CBAM channel-attention gate.

    Each spatial pooling named in ``pool_types`` is passed through one shared
    two-layer MLP. The MLP outputs are summed and squashed with a sigmoid,
    and the result rescales every channel of the input.

    Args:
        gate_channels: number of channels C of the (N, C, H, W) input.
        reduction_ratio: MLP hidden width is ``gate_channels // reduction_ratio``.
        pool_types: iterable of names drawn from ``'avg'``, ``'max'``,
            ``'lp'``, ``'lse'``. Defaults to ``['avg', 'max']``.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown name.
    """

    _POOL_TYPES = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None):
        super(ChannelGate, self).__init__()
        # None (not a list literal) as default so instances never share one
        # mutable list; a caller-supplied iterable is copied for the same reason.
        pool_types = ['avg', 'max'] if pool_types is None else list(pool_types)
        if not pool_types:
            raise ValueError("pool_types must contain at least one pooling type")
        unknown = [p for p in pool_types if p not in self._POOL_TYPES]
        if unknown:
            raise ValueError(f"Unknown pool type(s) {unknown}; expected names from {self._POOL_TYPES}")
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types

    @staticmethod
    def _pool(x, pool_type):
        """Reduce ``x`` over its spatial dims with the named pooling."""
        window = (x.size(2), x.size(3))
        if pool_type == 'avg':
            return F.avg_pool2d(x, window, stride=window)
        if pool_type == 'max':
            return F.max_pool2d(x, window, stride=window)
        if pool_type == 'lp':
            return F.lp_pool2d(x, 2, window, stride=window)
        if pool_type == 'lse':
            return logsumexp_2d(x)
        # Without this, an unknown name would reuse the previous iteration's
        # result (or hit an unbound local) instead of failing loudly.
        raise ValueError(f"Unknown pool type {pool_type!r}")

    def forward(self, x):
        """Return ``x`` scaled per channel by the sigmoid of the summed MLP outputs."""
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._pool(x, pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
|
|
def logsumexp_2d(tensor):
    """Log-sum-exp over all dimensions after the first two.

    Args:
        tensor: tensor of shape (N, C, ...), typically (N, C, H, W). It does
            not need to be contiguous.

    Returns:
        Tensor of shape (N, C, 1).

    ``torch.logsumexp`` is numerically stable, including for slices that
    are entirely ``-inf``: those give ``-inf`` rather than the NaN a manual
    max-shift produces.
    """
    flat = tensor.reshape(tensor.size(0), tensor.size(1), -1)
    return torch.logsumexp(flat, dim=2, keepdim=True)
|
|
class ChannelPool(nn.Module):
    """Compress the channel axis into two maps: per-position max and mean.

    For input (N, C, ...) the output is (N, 2, ...). Channel 0 holds the max
    over C and channel 1 holds the mean over C.
    """

    def forward(self, x):
        max_map = x.max(dim=1, keepdim=True).values
        mean_map = x.mean(dim=1, keepdim=True)
        return torch.cat([max_map, mean_map], dim=1)
|
|
class SpatialGate(nn.Module):
    """CBAM spatial-attention gate.

    Channel-pools the input to a 2-channel (max, mean) map. That map goes
    through a single-output ODConv2d with stride 1 and padding
    ``(kernel_size - 1) // 2``, which keeps height and width for odd
    kernels. The input is then rescaled by the sigmoid of the result.

    Args:
        kernel_size: spatial size of the ODConv2d kernel. Defaults to 7.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer (an even
            size would shrink the output and break the broadcast with ``x``).
    """

    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.compress = ChannelPool()
        self.spatial = ODConv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        """Return ``x`` multiplied by a single-channel sigmoid spatial mask."""
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)
        return x * scale
|
|
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then optional spatial gate.

    Args:
        gate_channels: channel count of the input, passed to the channel gate.
        reduction_ratio: MLP reduction ratio for the channel gate.
        pool_types: pooling names for the channel gate; defaults to
            ``['avg', 'max']``.
        no_spatial: when True, skip the spatial gate.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None, no_spatial=False):
        super(CBAM, self).__init__()
        # Build a fresh default list per instance; a list literal as the
        # default would be shared by every CBAM (and stored by ChannelGate).
        if pool_types is None:
            pool_types = ['avg', 'max']
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply the channel gate and, unless disabled, the spatial gate."""
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
|
|
|
|
class Dual(nn.Module):
    """CNN feature extractor + CBAM + single-step GRU + MLP head with 5 outputs.

    The input is expected as (batch, 1, H, W), since the first conv takes one
    channel. The four max-pools have a combined stride of
    2 * 4 * 4 * 4 = 128, and the GRU path needs the map reduced to exactly
    1x1. H and W must therefore each lie in [128, 255].

    NOTE(review): ``mfcc`` is presumably an MFCC map; confirm against the
    data pipeline.
    """

    def __init__(self):
        super(Dual, self).__init__()

        # Four conv-BN-ReLU-maxpool stages: 1 -> 64 -> 64 -> 128 -> 128 channels.
        self.feature_extractor2 = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 2)),

            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((4, 4), (4, 4)),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((4, 4), (4, 4)),

            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((4, 4), (4, 4)),
        )

        self.cbam = CBAM(128)

        # Consumes a length-1 sequence of 128-d feature vectors.
        self.gru = nn.GRU(
            input_size=128,
            hidden_size=256,
            batch_first=True,
        )

        self.fc2 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 5),
        )

        # NOTE(review): fc3 directly follows fc2's final Linear with no
        # activation in between, so the two compose to a single affine map.
        self.fc3 = nn.Linear(5, 5)

    def forward(self, mfcc):
        """Map a (batch, 1, H, W) input to (batch, 5) outputs.

        Raises:
            ValueError: if the CNN output is not (batch, 128, 1, 1), i.e. the
                input height/width fall outside [128, 255].
        """
        x = self.feature_extractor2(mfcc)
        x = self.cbam(x)

        # The GRU needs one 128-d vector per sample, so the spatial map must be
        # exactly 1x1. A rank-only check cannot catch larger maps, which would
        # otherwise reach the GRU as a 5-D tensor and fail there.
        if x.dim() != 4 or tuple(x.shape[-2:]) != (1, 1):
            raise ValueError(
                f"Invalid shape after CNN: {x.shape}; expected (batch, 128, 1, 1) -- "
                "input height and width must each be in [128, 255]"
            )

        x = x.flatten(1)      # (batch, 128)
        x = x.unsqueeze(1)    # (batch, 1, 128): one time step for the batch_first GRU

        x, _ = self.gru(x)    # (batch, 1, 256)

        x = self.fc2(x)       # flatten to (batch, 256), then MLP to (batch, 5)
        x = self.fc3(x)

        return x