File size: 3,765 Bytes
4e9a3bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Spectrogram and scalogram generation."""

import logging
import numpy as np
import librosa
from scipy import signal

from src.features.schemas import SpectrogramFeatures

logger = logging.getLogger(__name__)


class SpectrogramBuilder:
    """Build log-mel spectrogram and wavelet scalogram."""

    def __init__(
        self,
        n_mels: int = 128,
        n_fft: int = 2048,
        hop_length: int = 512,
        fmin: float = 0,
        fmax: float = 8000,
        use_wavelet: bool = True,
        wavelet_type: str = "morl",
        num_scales: int = 128,
    ):
        """Initialize spectrogram builder."""
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.use_wavelet = use_wavelet
        self.wavelet_type = wavelet_type
        self.num_scales = num_scales

    def build(self, waveform: np.ndarray, sr: int) -> SpectrogramFeatures:
        """
        Build spectrogram features.

        Args:
            waveform: Audio waveform
            sr: Sample rate

        Returns:
            SpectrogramFeatures with log-mel and optional wavelet scalogram
        """
        logger.debug(f"Building spectrogram: n_mels={self.n_mels}")

        # 1. Compute log-mel spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=waveform,
            sr=sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=self.fmin,
            fmax=self.fmax,
            power=2.0,
        )

        # Convert to log scale
        log_mel = librosa.power_to_db(mel_spec, ref=np.max)

        # 2. Compute wavelet scalogram (optional)
        wavelet_scalogram = None
        if self.use_wavelet:
            try:
                wavelet_scalogram = self._compute_wavelet_scalogram(waveform, sr)
            except Exception as e:
                logger.warning(f"Wavelet scalogram computation failed: {e}")

        # 3. Stack log-mel + scalogram as multi-channel input
        stacked = None
        if wavelet_scalogram is not None:
            # Resize scalogram to match log-mel dimensions
            if wavelet_scalogram.shape != log_mel.shape:
                from scipy.ndimage import zoom

                zoom_factors = (
                    log_mel.shape[0] / wavelet_scalogram.shape[0],
                    log_mel.shape[1] / wavelet_scalogram.shape[1],
                )
                wavelet_scalogram = zoom(wavelet_scalogram, zoom_factors, order=1)

            # Stack as (2, freq, time)
            stacked = np.stack([log_mel, wavelet_scalogram], axis=0)
        else:
            # Use only log-mel, duplicated
            stacked = np.stack([log_mel, log_mel], axis=0)

        logger.info(f"Spectrogram built: log_mel={log_mel.shape}, stacked={stacked.shape}")

        return SpectrogramFeatures(
            log_mel=log_mel,
            wavelet_scalogram=wavelet_scalogram,
            stacked=stacked,
            n_mels=self.n_mels,
            hop_length=self.hop_length,
            sr=sr,
        )

    def _compute_wavelet_scalogram(self, waveform: np.ndarray, sr: int) -> np.ndarray:
        """Compute wavelet CWT scalogram."""
        # Define scales
        scales = np.arange(1, self.num_scales + 1)

        # Compute CWT
        coefficients, frequencies = signal.cwt(
            waveform, signal.morlet2, scales, w=5.0
        )

        # Convert to power (magnitude squared)
        scalogram = np.abs(coefficients) ** 2

        # Log scale
        scalogram = np.log1p(scalogram)

        # Normalize
        scalogram = (scalogram - scalogram.min()) / (scalogram.max() - scalogram.min() + 1e-8)

        return scalogram