StrokeMitra-API / src /features /spectrogram_builder.py
DhruvB1906's picture
Upload folder using huggingface_hub
4e9a3bc verified
"""Spectrogram and scalogram generation."""
import logging
import numpy as np
import librosa
from scipy import signal
from src.features.schemas import SpectrogramFeatures
logger = logging.getLogger(__name__)
class SpectrogramBuilder:
"""Build log-mel spectrogram and wavelet scalogram."""
def __init__(
self,
n_mels: int = 128,
n_fft: int = 2048,
hop_length: int = 512,
fmin: float = 0,
fmax: float = 8000,
use_wavelet: bool = True,
wavelet_type: str = "morl",
num_scales: int = 128,
):
"""Initialize spectrogram builder."""
self.n_mels = n_mels
self.n_fft = n_fft
self.hop_length = hop_length
self.fmin = fmin
self.fmax = fmax
self.use_wavelet = use_wavelet
self.wavelet_type = wavelet_type
self.num_scales = num_scales
def build(self, waveform: np.ndarray, sr: int) -> SpectrogramFeatures:
"""
Build spectrogram features.
Args:
waveform: Audio waveform
sr: Sample rate
Returns:
SpectrogramFeatures with log-mel and optional wavelet scalogram
"""
logger.debug(f"Building spectrogram: n_mels={self.n_mels}")
# 1. Compute log-mel spectrogram
mel_spec = librosa.feature.melspectrogram(
y=waveform,
sr=sr,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
fmin=self.fmin,
fmax=self.fmax,
power=2.0,
)
# Convert to log scale
log_mel = librosa.power_to_db(mel_spec, ref=np.max)
# 2. Compute wavelet scalogram (optional)
wavelet_scalogram = None
if self.use_wavelet:
try:
wavelet_scalogram = self._compute_wavelet_scalogram(waveform, sr)
except Exception as e:
logger.warning(f"Wavelet scalogram computation failed: {e}")
# 3. Stack log-mel + scalogram as multi-channel input
stacked = None
if wavelet_scalogram is not None:
# Resize scalogram to match log-mel dimensions
if wavelet_scalogram.shape != log_mel.shape:
from scipy.ndimage import zoom
zoom_factors = (
log_mel.shape[0] / wavelet_scalogram.shape[0],
log_mel.shape[1] / wavelet_scalogram.shape[1],
)
wavelet_scalogram = zoom(wavelet_scalogram, zoom_factors, order=1)
# Stack as (2, freq, time)
stacked = np.stack([log_mel, wavelet_scalogram], axis=0)
else:
# Use only log-mel, duplicated
stacked = np.stack([log_mel, log_mel], axis=0)
logger.info(f"Spectrogram built: log_mel={log_mel.shape}, stacked={stacked.shape}")
return SpectrogramFeatures(
log_mel=log_mel,
wavelet_scalogram=wavelet_scalogram,
stacked=stacked,
n_mels=self.n_mels,
hop_length=self.hop_length,
sr=sr,
)
def _compute_wavelet_scalogram(self, waveform: np.ndarray, sr: int) -> np.ndarray:
"""Compute wavelet CWT scalogram."""
# Define scales
scales = np.arange(1, self.num_scales + 1)
# Compute CWT
coefficients, frequencies = signal.cwt(
waveform, signal.morlet2, scales, w=5.0
)
# Convert to power (magnitude squared)
scalogram = np.abs(coefficients) ** 2
# Log scale
scalogram = np.log1p(scalogram)
# Normalize
scalogram = (scalogram - scalogram.min()) / (scalogram.max() - scalogram.min() + 1e-8)
return scalogram