"""Pure Python (NumPy-only) implementation of Cohere Transcribe mel spectrogram preprocessing.

This matches the exact preprocessing used by the Cohere model, without requiring
the ``transformers`` library's feature extractor.
"""
|
|
| import numpy as np |
|
|
|
|
class CohereMelSpectrogram:
    """Log-mel spectrogram preprocessor matching Cohere Transcribe's parameters.

    Pipeline (see ``__call__``): reflect-padded STFT with a Hann window,
    power spectrum, triangular mel filterbank, then ``log10`` with a floor
    of ``1e-10``.

    Args:
        sample_rate: Sampling rate of the input audio in Hz.
        n_fft: FFT size; also the frame (window) length in samples.
        hop_length: Number of samples between successive frames.
        n_mels: Number of mel channels.
        fmin: Lowest filterbank edge frequency in Hz.
        fmax: Highest filterbank edge frequency in Hz; must not exceed
            ``sample_rate / 2``.

    Raises:
        ValueError: If a size parameter is not positive, or if the frequency
            range does not satisfy ``0 <= fmin < fmax <= sample_rate / 2``.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_fft=1024,
        hop_length=160,
        n_mels=128,
        fmin=0.0,
        fmax=8000.0,
    ):
        if sample_rate <= 0 or n_fft <= 0 or hop_length <= 0 or n_mels <= 0:
            raise ValueError(
                "sample_rate, n_fft, hop_length and n_mels must all be positive"
            )
        # Edge frequencies above Nyquist map to FFT bins past n_fft // 2, for
        # which the filterbank matrix has no column; negative ones map to
        # negative (wrap-around) indices.
        if not 0 <= fmin < fmax <= sample_rate / 2:
            raise ValueError(
                "expected 0 <= fmin < fmax <= sample_rate / 2, got "
                f"fmin={fmin}, fmax={fmax}, sample_rate={sample_rate}"
            )

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax

        # (n_mels, n_fft // 2 + 1) matrix applied to the power spectrum.
        self.mel_filters = self._create_mel_filterbank()

    def _create_mel_filterbank(self):
        """Build the triangular mel filterbank.

        Uses the ``2595 * log10(1 + f / 700)`` mel scale.

        Returns:
            float64 array of shape ``(n_mels, n_fft // 2 + 1)``. Row ``m`` is a
            triangle over the FFT bins spanned by three consecutive mel edge
            points, with values in ``[0, 1]`` (no area normalization).
        """

        def hz_to_mel(hz):
            return 2595 * np.log10(1 + hz / 700)

        def mel_to_hz(mel):
            return 700 * (10 ** (mel / 2595) - 1)

        # n_mels + 2 edge points equally spaced on the mel scale; filter m
        # uses the (left, center, right) triple starting at index m.
        mel_points = np.linspace(
            hz_to_mel(self.fmin), hz_to_mel(self.fmax), self.n_mels + 2
        )
        hz_points = mel_to_hz(mel_points)

        # Map each edge frequency to an FFT bin index. The (n_fft + 1) factor
        # is kept as-is so the resulting filterbank is unchanged.
        bin_points = np.floor(
            (self.n_fft + 1) * hz_points / self.sample_rate
        ).astype(int)

        fbank = np.zeros((self.n_mels, self.n_fft // 2 + 1))
        for m in range(self.n_mels):
            f_left, f_center, f_right = bin_points[m : m + 3]
            # Neighbouring edge points can floor to the same bin (typically at
            # low frequencies); the matching range is then empty, which also
            # keeps both divisions below away from zero denominators.
            for k in range(f_left, f_center):
                fbank[m, k] = (k - f_left) / (f_center - f_left)
            for k in range(f_center, f_right):
                fbank[m, k] = (f_right - k) / (f_right - f_center)

        return fbank

    def __call__(self, audio):
        """Compute the log-mel spectrogram of a mono waveform.

        Args:
            audio: 1-D array-like of audio samples (float, range roughly -1
                to 1); converted to ``float32``.

        Returns:
            ``(1, n_mels, n_frames)`` numpy array of ``log10`` mel power, with
            ``n_frames = 1 + (len(audio) + 2 * (n_fft // 2) - n_fft) // hop_length``
            (i.e. ``1 + len(audio) // hop_length`` for even ``n_fft``).

        Raises:
            ValueError: If ``audio`` is not one-dimensional or is empty.
        """
        audio = np.asarray(audio, dtype=np.float32)
        if audio.ndim != 1:
            raise ValueError(f"audio must be 1-D, got shape {audio.shape}")
        if audio.size == 0:
            raise ValueError("audio must contain at least one sample")

        power = np.abs(self._stft(audio)) ** 2
        mel = self.mel_filters @ power

        # Floor before the log so silent bins give a finite value (-10)
        # rather than -inf.
        log_mel = np.log10(np.maximum(mel, 1e-10))

        # Add a leading batch axis of size 1.
        return log_mel[np.newaxis, :, :]

    def _stft(self, audio):
        """Short-time Fourier transform of a 1-D signal.

        The signal is reflect-padded by ``n_fft // 2`` samples on both sides,
        so frame ``i`` is centred near input sample ``i * hop_length``. Each
        frame is multiplied by ``np.hanning(n_fft)`` (a symmetric Hann window)
        and transformed with ``np.fft.rfft``.

        Args:
            audio: 1-D float array of samples.

        Returns:
            ``complex64`` array of shape ``(n_fft // 2 + 1, n_frames)``.
        """
        pad_length = self.n_fft // 2
        audio_padded = np.pad(audio, (pad_length, pad_length), mode="reflect")

        window = np.hanning(self.n_fft)
        n_frames = 1 + (len(audio_padded) - self.n_fft) // self.hop_length

        stft = np.empty((self.n_fft // 2 + 1, n_frames), dtype=np.complex64)

        # Transform frames in batches: one vectorized rfft per batch instead
        # of one Python-level call per frame, while keeping the temporary
        # (batch, n_fft) frame matrix small for long inputs.
        frame_offsets = np.arange(self.n_fft)
        batch_size = 256
        for first in range(0, n_frames, batch_size):
            last = min(first + batch_size, n_frames)
            starts = np.arange(first, last)[:, np.newaxis] * self.hop_length
            frames = audio_padded[starts + frame_offsets]  # (last - first, n_fft)
            stft[:, first:last] = np.fft.rfft(frames * window, axis=-1).T

        return stft
|
|