cohere-transcribe-03-2026-coreml / cohere_mel_spectrogram.py
alexwengg's picture
Upload 17 files
50731c6 verified
#!/usr/bin/env python3
"""Pure Python implementation of Cohere Transcribe mel spectrogram preprocessing.
This matches the exact preprocessing used by the Cohere model, without requiring
the transformers library's feature extractor.
"""
import numpy as np
class CohereMelSpectrogram:
    """Log-mel spectrogram preprocessor for Cohere Transcribe audio input.

    Pipeline: reflect-padded, Hann-windowed STFT -> power spectrum ->
    triangular mel filterbank -> ``log10`` with a ``1e-10`` floor.

    NOTE(review): the defaults are intended to mirror the model's reference
    feature extractor; numerical parity is not established by this file and
    should be confirmed against that extractor.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_fft=1024,
        hop_length=160,
        n_mels=128,
        fmin=0.0,
        fmax=8000.0,
    ):
        """Store the analysis parameters and precompute the mel filterbank.

        Args:
            sample_rate: Sampling rate of the input audio, in Hz.
            n_fft: FFT size; also used as the analysis window length.
            hop_length: Number of samples between consecutive frames.
            n_mels: Number of mel bands (rows of the filterbank).
            fmin: Lower edge of the filterbank, in Hz.
            fmax: Upper edge of the filterbank, in Hz.
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax
        # The filterbank depends only on the parameters above, so build it once.
        self.mel_filters = self._create_mel_filterbank()

    def _create_mel_filterbank(self):
        """Build the triangular mel filterbank.

        Band edges are spaced evenly on the mel scale
        (``mel = 2595 * log10(1 + hz / 700)``) between ``fmin`` and ``fmax``
        and then snapped down to FFT bin indices.

        Returns:
            Array of shape ``(n_mels, n_fft // 2 + 1)``. Each row rises
            linearly from 0 at its left edge bin to 1 at its centre bin and
            falls back towards 0 at its right edge bin (right edge excluded).
        """

        def hz_to_mel(hz):
            return 2595 * np.log10(1 + hz / 700)

        def mel_to_hz(mel):
            return 700 * (10 ** (mel / 2595) - 1)

        # n_mels triangles need n_mels + 2 edge points (left, centre, right).
        mel_points = np.linspace(
            hz_to_mel(self.fmin), hz_to_mel(self.fmax), self.n_mels + 2
        )
        hz_points = mel_to_hz(mel_points)
        bin_points = np.floor(
            (self.n_fft + 1) * hz_points / self.sample_rate
        ).astype(int)

        fbank = np.zeros((self.n_mels, self.n_fft // 2 + 1))
        edges = zip(bin_points[:-2], bin_points[1:-1], bin_points[2:])
        for row, (left, center, right) in enumerate(edges):
            # Neighbouring edges can land on the same bin when the mel spacing
            # is finer than the FFT resolution; that slope is then empty and
            # must be skipped to avoid a zero-width division.
            if center > left:
                rising = np.arange(left, center)
                fbank[row, rising] = (rising - left) / (center - left)
            if right > center:
                falling = np.arange(center, right)
                fbank[row, falling] = (right - falling) / (right - center)
        return fbank

    def __call__(self, audio):
        """Compute the log-mel spectrogram of a mono audio signal.

        Args:
            audio: 1D sequence or numpy array of audio samples; it is
                converted to float32 before processing.

        Returns:
            mel: ``(1, n_mels, n_frames)`` numpy array of ``log10`` mel
            energies, where ``n_frames = 1 + len(audio) // hop_length`` for an
            even ``n_fft``.

        Raises:
            ValueError: If ``audio`` is not one-dimensional or is empty.
        """
        # np.asarray (rather than .astype) also accepts plain Python sequences.
        audio = np.asarray(audio, dtype=np.float32)
        if audio.ndim != 1:
            raise ValueError(
                f"audio must be a 1D array of samples, got shape {audio.shape}"
            )
        if audio.size == 0:
            raise ValueError("audio must contain at least one sample")

        stft = self._stft(audio)
        # Power spectrum: squared magnitude of each STFT coefficient.
        power = np.abs(stft) ** 2
        mel = np.dot(self.mel_filters, power)
        # Floor at 1e-10 so silent bands map to -10 instead of -inf.
        mel = np.log10(np.maximum(mel, 1e-10))
        # Prepend the batch dimension.
        return mel[np.newaxis, :, :]

    def _stft(self, audio):
        """Compute the Short-Time Fourier Transform of ``audio``.

        The signal is reflect-padded by ``n_fft // 2`` samples on both sides,
        so frame ``i`` is centred on sample ``i * hop_length``. Each frame is
        weighted by ``np.hanning(n_fft)`` (the symmetric Hann window).

        Args:
            audio: 1D numpy array of samples.

        Returns:
            Complex64 array of shape ``(n_fft // 2 + 1, n_frames)``.
        """
        pad_length = self.n_fft // 2
        audio_padded = np.pad(audio, (pad_length, pad_length), mode="reflect")
        window = np.hanning(self.n_fft)

        # Zero-copy view of every length-n_fft window, thinned to one per hop:
        # shape (n_frames, n_fft) with n_frames = 1 + (len - n_fft) // hop.
        frames = np.lib.stride_tricks.sliding_window_view(
            audio_padded, self.n_fft
        )[:: self.hop_length]

        # One batched real FFT over all frames instead of a Python loop.
        spectrum = np.fft.rfft(frames * window, axis=-1)

        # Frequency-major, C-contiguous complex64, as callers expect.
        return np.ascontiguousarray(spectrum.T, dtype=np.complex64)