import torch
import torch.nn.functional as F
from torch import nn


# Produces the four ODConv attentions (channel, filter, spatial, kernel)
# from a globally pooled feature descriptor.
class Attention(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, groups=1,
                 reduction=0.0625, kernel_num=4, min_channel=16):
        super(Attention, self).__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(attention_channel)
        self.relu = nn.ReLU(inplace=True)

        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention

        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention

        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return channel_attention

    def get_filter_attention(self, x):
        filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return filter_attention

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
        return spatial_attention

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
        return kernel_attention

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


# Omni-dimensional dynamic convolution: kernel_num candidate kernels are mixed
# per sample by the attentions computed in Attention.
class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(
            torch.randn(kernel_num, out_planes, in_planes // groups, kernel_size, kernel_size),
            requires_grad=True)
        self._initialize_weights()

        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common

    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')

    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)

    def _forward_impl_common(self, x):
        # Applying the channel attention to the input (and the filter attention to the
        # output) is equivalent to folding both into the weights, but the former runs
        # faster with less GPU memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        # Fold the batch into the channel dimension so each sample gets its own
        # aggregated kernel via a grouped convolution.
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride,
                          padding=self.padding, dilation=self.dilation,
                          groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output

    def _forward_impl_pw1x(self, x):
        # Point-wise, single-kernel case: spatial and kernel attentions degenerate to 1.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride,
                          padding=self.padding, dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output

    def forward(self, x):
        return self._forward_impl(x)
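
# Usage sketch for ODConv2d (illustration only; the 32/64 channel counts and the
# 16x16 input are arbitrary assumed values, not anything prescribed by this module).
# It behaves as a drop-in replacement for nn.Conv2d, with update_temperature()
# exposed for annealing the attention temperature during training:
#
#   conv = ODConv2d(32, 64, kernel_size=3, padding=1)
#   y = conv(torch.randn(2, 32, 16, 16))  # -> torch.Size([2, 64, 16, 16])
#   conv.update_temperature(0.5)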


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        # Note: the `bias` argument is kept for interface compatibility but is unused,
        # since ODConv2d has no bias term.
        self.conv = ODConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                             padding=padding, dilation=dilation, groups=groups)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = ODConv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
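
# Usage sketch for CBAM (illustration only; the 128 channels and 8x8 map are
# arbitrary assumed values). The block is shape-preserving: channel attention is
# applied first, then the ODConv-based spatial attention:
#
#   block = CBAM(gate_channels=128)
#   out = block(torch.randn(2, 128, 8, 8))  # -> torch.Size([2, 128, 8, 8])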


class Dual(nn.Module):
    def __init__(self):
        super(Dual, self).__init__()
        self.feature_extractor2 = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 2)),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((4, 4), (4, 4)),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((4, 4), (4, 4)),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((4, 4), (4, 4))
        )
        self.cbam = CBAM(128)
        self.gru = nn.GRU(input_size=128, hidden_size=256, batch_first=True)
        self.fc2 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 5)
        )
        self.fc3 = nn.Linear(5, 5)

    def forward(self, mfcc):
        x = self.feature_extractor2(mfcc)
        x = self.cbam(x)
        if x.dim() != 4:
            raise ValueError(f"Invalid shape after CNN: {x.shape}")
        # The pooling stack downsamples 2*4*4*4 = 128x per spatial dimension, so the
        # feature map is expected to be 1x1 here (e.g. for a 128x128 MFCC input).
        # (B, 128, 1, 1) -> (B, 128)
        x = x.squeeze(-1).squeeze(-1)
        # -> (B, 1, 128): a single-step sequence for the GRU
        x = x.unsqueeze(1)
        x, _ = self.gru(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
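
# Minimal smoke test, a sketch under assumptions: the input must reduce to a 1x1
# map after the pooling stack for the squeeze in Dual.forward to hold (e.g. a
# 128x128 input); the batch size of 2 is an arbitrary illustration value.
if __name__ == "__main__":
    model = Dual()
    mfcc = torch.randn(2, 1, 128, 128)  # assumed (batch, channel, MFCC bins, frames)
    logits = model(mfcc)
    print(logits.shape)  # expected: torch.Size([2, 5])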