"""
Spectral Residual Encoder (ResNet-Audio) - NAN SAFE.

Turns raw waveforms into a log-mel spectrogram with a fixed (parameter-free)
front end, then encodes it with a stack of 2-D residual blocks that downsample
only along the mel axis, yielding one feature vector per STFT frame.
"""
import logging
import math

import torch
import torch.nn as nn
import torchaudio.functional as F

logger = logging.getLogger(__name__)


class ResBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip, then ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 convolution (an int or an ``(h, w)`` tuple).
            When it is not ``1`` or the channel count changes, the skip path gets
            a 1x1 conv + BN projection so it matches the main path's shape.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        # NOTE: a tuple stride such as (2, 1) compares != 1, so it always gets a projection.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the block to a 4-D feature map ``x`` and return the activated sum."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class SpectralEncoder(nn.Module):
    """Log-mel front end followed by a 4-stage residual CNN.

    Args:
        config: Nested dict. Reads ``config["audio"]`` keys ``n_fft``,
            ``hop_length``, ``n_mels``, ``sample_rate`` and
            ``config["model"]["spectral_encoder"]["output_dim"]``.
    """

    def __init__(self, config: dict):
        super().__init__()
        audio_cfg = config["audio"]
        model_cfg = config["model"]["spectral_encoder"]

        self.n_fft = audio_cfg["n_fft"]
        self.hop_length = audio_cfg["hop_length"]
        self.n_mels = audio_cfg["n_mels"]
        self.sample_rate = audio_cfg["sample_rate"]

        # Filterbank of shape (n_freqs, n_mels); kept as a buffer so it follows .to(device).
        mel_basis = F.melscale_fbanks(
            n_freqs=(self.n_fft // 2) + 1,
            f_min=0.0,
            f_max=self.sample_rate / 2.0,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm='slaney',
            mel_scale='htk',
        )
        self.register_buffer('mel_basis', mel_basis)

        window = torch.hann_window(self.n_fft)
        self.register_buffer('window', window)

        self.output_dim = model_cfg["output_dim"]

        self.stem = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        # stride=(2, 1): halve the mel (height) axis each stage, keep the time axis intact.
        self.layer1 = self._make_layer(32, 64, stride=(2, 1))
        self.layer2 = self._make_layer(64, 128, stride=(2, 1))
        self.layer3 = self._make_layer(128, 256, stride=(2, 1))
        self.layer4 = self._make_layer(256, self.output_dim, stride=(2, 1))

        # Collapse whatever remains of the mel axis to 1, keep the time axis as-is.
        self.freq_pool = nn.AdaptiveAvgPool2d((1, None))

    def _make_layer(self, in_c, out_c, stride):
        """Build one encoder stage (currently a single ResBlock)."""
        return ResBlock(in_c, out_c, stride)

    def _prepare_waveform(self, waveform):
        """Normalize the waveform layout to (batch, samples) and scrub non-finite samples.

        Accepts ``(T,)`` (treated as a batch of one) and ``(B, 1, T)`` (channel
        axis squeezed) in addition to ``(B, T)``. Other layouts are passed
        through unchanged, and ``torch.stft`` rejects them as before.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() == 3 and waveform.size(1) == 1:
            waveform = waveform.squeeze(1)
        # torch.clamp propagates NaN, so the later clamps cannot catch it; treat
        # NaN/±Inf samples as silence so a single corrupt sample cannot poison
        # every STFT frame it falls into.
        return torch.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)

    def _safe_stft(self, waveform):
        """Return the STFT magnitude of ``waveform`` with overflow/NaN guards.

        Returns:
            Tensor of shape ``(..., n_fft // 2 + 1, frames)``. Each value is
            ``sqrt(re^2 + im^2 + 1e-9)``, so it is never exactly zero.
        """
        if not waveform.is_contiguous():
            waveform = waveform.contiguous()

        complex_spec = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # nan_to_num first (clamp lets NaN through), then clamp so pow() cannot overflow.
        real = torch.clamp(torch.nan_to_num(complex_spec.real, nan=0.0), min=-1e4, max=1e4)
        imag = torch.clamp(torch.nan_to_num(complex_spec.imag, nan=0.0), min=-1e4, max=1e4)

        # The epsilon keeps sqrt's gradient finite at zero magnitude.
        mag_spec = torch.sqrt(real.pow(2) + imag.pow(2) + 1e-9)
        return mag_spec

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode a waveform batch into per-frame feature vectors.

        Args:
            waveform: ``(B, T)`` batch of mono signals; ``(T,)`` and ``(B, 1, T)``
                are also accepted (a 1-D input yields a batch of size 1).

        Returns:
            Tensor of shape ``(B, frames, output_dim)``, where ``frames`` is the
            STFT frame count (``center=True``).
        """
        # Only the fixed, parameter-free front end runs without autograd; the
        # CNN below must stay outside no_grad so its weights can be trained.
        with torch.no_grad():
            waveform = self._prepare_waveform(waveform)
            mag_spec = self._safe_stft(waveform)
            # (n_mels, n_freqs) @ (B, n_freqs, frames) -> (B, n_mels, frames)
            melspec = torch.matmul(self.mel_basis.transpose(0, 1), mag_spec)
            # Safe log: clamp the minimum to prevent -inf (and the maximum to bound the range).
            melspec = torch.log10(torch.clamp(melspec, min=1e-5, max=1e5))

        x = melspec.unsqueeze(1)  # add the single input channel: (B, 1, n_mels, frames)
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.freq_pool(x)  # (B, output_dim, 1, frames)
        x = x.squeeze(2)       # (B, output_dim, frames)
        x = x.transpose(1, 2)  # (B, frames, output_dim)
        return x