from typing import Tuple import nnAudio.features.cqt as nnAudio import torch import torch.nn as nn import torch.nn.functional as F import torchaudio.transforms as T class GeM(nn.Module): def __init__(self, p=3, eps=1e-6): super(GeM, self).__init__() self.p = nn.Parameter(torch.ones(1) * p) self.eps = eps def forward(self, x): return self.gem(x, p=self.p, eps=self.eps) def gem(self, x, p=3, eps=1e-6): return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p) def __repr__(self): return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={str(self.eps)})" class IBN(nn.Module): r"""Instance-Batch Normalization layer from `"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net" ` Args: planes (int): Number of channels for the input tensor ratio (float): Ratio of instance normalization in the IBN layer """ def __init__(self, planes, ratio): super(IBN, self).__init__() self.half = int(planes * ratio) self.IN = nn.InstanceNorm2d(self.half, affine=True) self.BN = nn.BatchNorm2d(planes - self.half) def forward(self, x): split = torch.split(x, self.half, 1) out1 = self.IN(split[0].contiguous()) out2 = self.BN(split[1].contiguous()) out = torch.cat((out1, out2), 1) return out class Bottleneck(nn.Module): expansion: int = 4 def __init__( self, in_channels: int, out_channels: int, last: bool = False, downsample=None, stride=1, bias: bool = True ): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias) if not last: # Apply Instance normalization in first half channels (ratio=0.5) self.ibn = IBN(out_channels, ratio=0.5) else: self.ibn = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias) self.batch_norm2 = nn.BatchNorm2d(out_channels) self.conv3 = nn.Conv2d( out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias=bias ) self.batch_norm3 = nn.BatchNorm2d(out_channels * self.expansion) self.downsample = downsample self.stride = stride self.relu = nn.ReLU() def forward(self, x: torch.Tensor): residual = x.clone() x = self.conv1(x) x = self.ibn(x) x = self.relu(x) x = self.conv2(x) x = self.batch_norm2(x) x = self.relu(x) x = self.conv3(x) x = self.batch_norm3(x) x = self.relu(x) if self.downsample is not None: residual = self.downsample(residual) out = residual + x out = self.relu(out) return out class Resnet50(nn.Module): def __init__( self, ResBlock: Bottleneck, emb_dim: int = 2048, num_channels: int = 1, num_classes: int = 8858, sr: int = 22050, hop_lenght: int = 512, n_bins=84, bins_per_octave=12, window="hann", compress_ratio: int = 20, tempo_factors: Tuple[float, float] = None, ) -> None: super(Resnet50, self).__init__() self.in_channels = 64 self.cqt = nnAudio.CQT2010v2( sr=sr, hop_length=hop_lenght, n_bins=n_bins, bins_per_octave=bins_per_octave, window=window, output_format="Complex", verbose=False, ) self.compress = nn.AvgPool2d((1, compress_ratio)) self.time_strech = T.TimeStretch(n_freq=n_bins) self.tempo_factors = tempo_factors self.conv1 = nn.Conv2d( in_channels=num_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False ) self.batch_norm1 = nn.BatchNorm2d(num_features=64) self.relu = nn.ReLU() self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(ResBlock, blocks=3, planes=64, stride=1) self.layer2 = self._make_layer(ResBlock, blocks=4, planes=128, stride=2) self.layer3 = self._make_layer(ResBlock, blocks=6, planes=256, stride=2) self.layer4 = self._make_layer(ResBlock, blocks=3, planes=512, stride=1, last=True) self.gem_pool = GeM() self.bn_fc = nn.BatchNorm1d(emb_dim) self.fc = nn.Linear(emb_dim, num_classes, bias=False) nn.init.kaiming_normal_(self.fc.weight) def _make_layer(self, ResBlock: Bottleneck, blocks: int, planes: int, stride: int = 1, last: bool = False): downsample = None if stride != 1 or self.in_channels != planes * ResBlock.expansion: downsample = nn.Sequential( nn.Conv2d(self.in_channels, planes * ResBlock.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * ResBlock.expansion), ) layers = [] layers.append( ResBlock(in_channels=self.in_channels, out_channels=planes, stride=stride, downsample=downsample, last=last) ) self.in_channels = planes * ResBlock.expansion for _ in range(1, blocks): layers.append(ResBlock(in_channels=self.in_channels, out_channels=planes, last=last)) return nn.Sequential(*layers) def forward(self, x: torch.Tensor): x = self.cqt(x) # Time-strech requires complex tensors, that's why cqt function returns complex x = torch.view_as_complex(x) if self.tempo_factors is not None: rate = abs(self.tempo_factors[1] - self.tempo_factors[0]) * torch.rand(1).item() + min(self.tempo_factors) strech = ( abs(1 - rate) > 7e-2 ) # if the strech ratio is too close to 1 (i.e. 0.93 < ratio < 1.07), skip time streching if self.training and strech: x = self.time_strech(x, rate) # Compress the magnitude of the CQT x = self.compress(torch.abs(x)) # Unsqueeze to simulate 1-channel image x = self.conv1(x.unsqueeze(1)) x = self.batch_norm1(x) x = self.relu(x) x = self.max_pool1(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) f_t = self.gem_pool(x) f_t = torch.flatten(f_t, start_dim=1) f_c = self.bn_fc(f_t) cls = self.fc(f_c) return dict(f_t=f_t, f_c=f_c, cls=cls)