File size: 3,656 Bytes
50731c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""Pure Python implementation of Cohere Transcribe mel spectrogram preprocessing.

This matches the exact preprocessing used by the Cohere model, without requiring
the transformers library's feature extractor.
"""

import numpy as np


class CohereMelSpectrogram:
    """NumPy-only log-mel spectrogram preprocessor for Cohere Transcribe.

    Pipeline, in order: reflect-pad the waveform by ``n_fft // 2`` samples on
    each side, take a Hann-windowed STFT, square the magnitudes, project the
    power spectrum onto a triangular mel filterbank and apply ``log10`` with
    a floor of ``1e-10``.

    NOTE(review): the default parameters are presumably those of the
    reference Cohere Transcribe feature extractor -- confirm against it
    before relying on bit-level parity.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_fft=1024,
        hop_length=160,
        n_mels=128,
        fmin=0.0,
        fmax=8000.0,
    ):
        """Store the analysis parameters and precompute the mel filterbank.

        Args:
            sample_rate: Sampling rate of the input audio, in Hz.
            n_fft: FFT size; also used as the analysis window length.
            hop_length: Number of samples between consecutive frames.
            n_mels: Number of mel filters (rows of the filterbank).
            fmin: Lower edge of the first mel filter, in Hz.
            fmax: Upper edge of the last mel filter, in Hz.
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax

        # Built once here; the matrix depends only on the parameters above.
        self.mel_filters = self._create_mel_filterbank()

    def _create_mel_filterbank(self):
        """Build the triangular mel filterbank.

        Filter edges are spaced uniformly on the mel scale
        (``2595 * log10(1 + hz / 700)``) between ``fmin`` and ``fmax`` and
        then snapped down to integer FFT bin indices.

        Returns:
            A ``(n_mels, n_fft // 2 + 1)`` float array. Each row rises
            linearly from 0 at its left edge bin to 1 at its centre bin and
            falls back towards 0 at its right edge bin. Rows whose three
            edge bins coincide stay all-zero.
        """

        def hz_to_mel(hz):
            return 2595 * np.log10(1 + hz / 700)

        def mel_to_hz(mel):
            return 700 * (10 ** (mel / 2595) - 1)

        # n_mels triangles need n_mels + 2 edge points (left, centre, right
        # are shared between neighbouring filters).
        mel_points = np.linspace(
            hz_to_mel(self.fmin), hz_to_mel(self.fmax), self.n_mels + 2
        )
        hz_points = mel_to_hz(mel_points)

        # Map each edge frequency to an FFT bin index.
        bin_points = np.floor(
            (self.n_fft + 1) * hz_points / self.sample_rate
        ).astype(int)

        fbank = np.zeros((self.n_mels, self.n_fft // 2 + 1))
        for row in range(self.n_mels):
            f_left, f_center, f_right = bin_points[row : row + 3]

            # Rising slope over [f_left, f_center). Skipped when the two
            # edges fall in the same bin, which also avoids a 0/0 division.
            if f_center > f_left:
                rising = np.arange(f_left, f_center)
                fbank[row, rising] = (rising - f_left) / (f_center - f_left)

            # Falling slope over [f_center, f_right); the centre bin gets 1.
            if f_right > f_center:
                falling = np.arange(f_center, f_right)
                fbank[row, falling] = (f_right - falling) / (f_right - f_center)

        return fbank

    def __call__(self, audio):
        """Compute the log-mel spectrogram of a mono waveform.

        Args:
            audio: 1-D array-like of audio samples. It is converted to
                float32 before processing.

        Returns:
            A ``(1, n_mels, n_frames)`` float64 array of ``log10`` mel
            energies, floored at ``log10(1e-10) == -10``. For an even
            ``n_fft``, ``n_frames == 1 + len(audio) // hop_length``.

        Raises:
            ValueError: If ``audio`` is not one-dimensional or is empty.
        """
        # asarray (rather than .astype) also accepts lists and other
        # array-likes, and avoids a copy when the input is already float32.
        audio = np.asarray(audio, dtype=np.float32)
        if audio.ndim != 1:
            raise ValueError(
                f"audio must be a 1-D array of samples, got shape {audio.shape}"
            )
        if audio.size == 0:
            raise ValueError("audio must contain at least one sample")

        # Power spectrum: squared magnitude of the complex STFT.
        power = np.abs(self._stft(audio)) ** 2

        # Project linear-frequency bins onto mel bands.
        mel = np.dot(self.mel_filters, power)

        # Floor before the log so silent bands give -10 instead of -inf.
        mel = np.log10(np.maximum(mel, 1e-10))

        # Prepend the batch dimension.
        return mel[np.newaxis, :, :]

    def _stft(self, audio):
        """Compute a centred Short-Time Fourier Transform.

        Args:
            audio: 1-D array of samples.

        Returns:
            A ``(n_fft // 2 + 1, n_frames)`` complex64 array holding the
            one-sided spectrum of each frame.
        """
        # Reflect-pad by half a window on each side so that frame i is
        # centred on sample i * hop_length of the original signal.
        pad_length = self.n_fft // 2
        audio_padded = np.pad(audio, (pad_length, pad_length), mode="reflect")

        # NOTE(review): np.hanning returns a *symmetric* Hann window; some
        # reference STFT implementations use the periodic variant -- confirm
        # which one the reference extractor expects.
        window = np.hanning(self.n_fft)

        n_frames = 1 + (len(audio_padded) - self.n_fft) // self.hop_length
        stft = np.zeros((self.n_fft // 2 + 1, n_frames), dtype=np.complex64)

        # One real FFT per frame; rfft's complex128 output is narrowed to
        # complex64 on assignment.
        for i in range(n_frames):
            start = i * self.hop_length
            frame = audio_padded[start : start + self.n_fft]
            stft[:, i] = np.fft.rfft(frame * window)

        return stft