# indicguard/src/models/spectral_encoder.py
# Last commit by realruneet: "Update src/models/spectral_encoder.py" (cac091b, verified)
"""
Spectral Residual Encoder (ResNet-Audio) - NAN SAFE
"""
import torch
import torch.nn as nn
import torchaudio.functional as F
import logging
import math
logger = logging.getLogger(__name__)
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages plus a skip connection.

    The first convolution carries ``stride``; the second always uses stride 1.
    When the skip path cannot be a plain identity (``stride != 1`` or the
    channel count changes), ``downsample`` holds a 1x1 conv + BatchNorm
    projection with the same stride; otherwise ``downsample`` is ``None``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution and of the projection
            shortcut. Passed straight to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch. Both convolutions are bias-free and are each followed
        # by a BatchNorm layer.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: project only when the shapes of the two branches
        # would otherwise disagree.
        needs_projection = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place add onto the freshly computed branch output; ``x`` itself
        # is never modified.
        out += shortcut
        return self.relu(out)
class SpectralEncoder(nn.Module):
    """Log-mel front end followed by a small residual CNN.

    The waveform is converted to a magnitude STFT, projected onto a mel
    filterbank and log10-compressed (all without gradient tracking). The
    resulting spectrogram is treated as a one-channel image and passed
    through a conv stem and four residual stages that stride only along the
    frequency axis, so the number of time frames is preserved. The remaining
    frequency axis is average-pooled away.

    Args:
        config: Nested mapping. Reads ``config["audio"]`` keys ``n_fft``,
            ``hop_length``, ``n_mels`` and ``sample_rate``, and
            ``config["model"]["spectral_encoder"]["output_dim"]``.
    """

    def __init__(self, config: dict):
        super().__init__()
        audio_cfg = config["audio"]
        model_cfg = config["model"]["spectral_encoder"]

        self.n_fft = audio_cfg["n_fft"]
        self.hop_length = audio_cfg["hop_length"]
        self.n_mels = audio_cfg["n_mels"]
        self.sample_rate = audio_cfg["sample_rate"]

        # Fixed (non-trainable) mel filterbank covering 0 Hz .. Nyquist.
        # Registered as a buffer so it follows the module across devices.
        mel_basis = F.melscale_fbanks(
            n_freqs=(self.n_fft // 2) + 1,
            f_min=0.0,
            f_max=self.sample_rate / 2.0,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm='slaney',
            mel_scale='htk',
        )
        self.register_buffer('mel_basis', mel_basis)

        window = torch.hann_window(self.n_fft)
        self.register_buffer('window', window)

        self.output_dim = model_cfg["output_dim"]

        self.stem = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        # Stride (2, 1): halve the frequency axis, keep the time axis.
        self.layer1 = self._make_layer(32, 64, stride=(2, 1))
        self.layer2 = self._make_layer(64, 128, stride=(2, 1))
        self.layer3 = self._make_layer(128, 256, stride=(2, 1))
        self.layer4 = self._make_layer(256, self.output_dim, stride=(2, 1))
        # Collapse whatever is left of the frequency axis; leave time as is.
        self.freq_pool = nn.AdaptiveAvgPool2d((1, None))

    def _make_layer(self, in_c, out_c, stride):
        """Build one residual stage (a single ``ResBlock``)."""
        return ResBlock(in_c, out_c, stride)

    def _safe_stft(self, waveform):
        """Return a finite magnitude spectrogram of ``waveform``.

        Args:
            waveform: Tensor accepted by ``torch.stft`` (1-D or 2-D). It is
                cast to the dtype of the analysis window.

        Returns:
            Magnitude tensor with ``n_fft // 2 + 1`` frequency bins on the
            second-to-last axis and STFT frames on the last axis. Every
            value is finite and strictly positive.
        """
        # The window and mel filterbank buffers share one dtype; casting here
        # keeps torch.stft and the later matmul from seeing mixed dtypes
        # (e.g. float64 or integer audio).
        waveform = waveform.to(dtype=self.window.dtype)

        # clamp() propagates NaN, and a single NaN/Inf sample contaminates
        # every STFT frame that overlaps it. Treat non-finite samples as
        # silence before transforming.
        waveform = torch.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)

        if not waveform.is_contiguous():
            waveform = waveform.contiguous()

        complex_spec = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=True
        )

        # Scrub any NaN produced by overflow inside the transform, then bound
        # the components so that squaring them cannot overflow to infinity.
        real = torch.nan_to_num(complex_spec.real, nan=0.0)
        imag = torch.nan_to_num(complex_spec.imag, nan=0.0)
        real = torch.clamp(real, min=-1e4, max=1e4)
        imag = torch.clamp(imag, min=-1e4, max=1e4)

        # The epsilon keeps the magnitude strictly positive.
        mag_spec = torch.sqrt(real.pow(2) + imag.pow(2) + 1e-9)
        return mag_spec

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode raw audio into a sequence of per-frame feature vectors.

        Args:
            waveform: Audio samples shaped ``(batch, samples)``. A 1-D
                ``(samples,)`` tensor is treated as a batch of one, and a
                single-channel ``(batch, 1, samples)`` tensor has its
                channel axis dropped. Non-finite samples are zeroed.

        Returns:
            Tensor shaped ``(batch, frames, output_dim)``, where ``frames``
            is the number of STFT frames.
        """
        # Normalise the accepted layouts to (batch, samples).
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() == 3 and waveform.size(1) == 1:
            waveform = waveform.squeeze(1)

        # The spectral front end is fixed, so no gradients are tracked here.
        with torch.no_grad():
            mag_spec = self._safe_stft(waveform)
            # (n_mels, n_freqs) @ (batch, n_freqs, frames)
            melspec = torch.matmul(self.mel_basis.transpose(0, 1), mag_spec)
            # Safe log: bound the argument so log10 never returns -inf/+inf.
            melspec = torch.log10(torch.clamp(melspec, min=1e-5, max=1e5))

        # Add a channel axis: (batch, 1, n_mels, frames).
        x = melspec.unsqueeze(1)
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # (batch, C, 1, frames) -> (batch, C, frames) -> (batch, frames, C)
        x = self.freq_pool(x)
        x = x.squeeze(2)
        x = x.transpose(1, 2)
        return x