| |
| """ |
| Example usage of Cohere Transcribe Cache-External CoreML models. |
| |
| Requirements: |
| pip install coremltools numpy librosa soundfile sentencepiece |
| """ |
|
|
| import argparse |
| from pathlib import Path |
| import numpy as np |
| import coremltools as ct |
| import soundfile as sf |
| import librosa |
| import sentencepiece as spm |
|
|
| |
# Audio / feature-extraction settings.
# Sample rate (Hz) that audio is resampled to before the mel spectrogram is computed.
SAMPLE_RATE = 16000
# Number of mel bands requested from librosa.feature.melspectrogram.
N_MELS = 128
# Hop between successive STFT frames, in samples (160 / 16000 Hz = 10 ms).
HOP_LENGTH = 160
# FFT size in samples (400 / 16000 Hz = 25 ms).
N_FFT = 400
# Fixed number of mel frames the encoder input is padded/truncated to
# (3500 frames * 160 samples / 16000 Hz = 35 s of audio).
MAX_FRAMES = 3500
# Upper bound on decoding steps; also the sequence dimension of the KV caches.
MAX_SEQ_LEN = 108



# Special token ids.
# Token id fed to the decoder at step 0.
START_TOKEN = 4
# Token id that terminates decoding when predicted.
EOS_TOKEN = 3
|
|
|
|
def compute_mel_spectrogram(audio, sr=SAMPLE_RATE):
    """Turn a waveform into a normalized log-mel spectrogram.

    The signal is resampled to ``SAMPLE_RATE`` when ``sr`` differs, converted
    to a mel power spectrogram, expressed in dB relative to its own maximum,
    then mapped with ``(dB + 80) / 80`` and clipped to ``[-1, 1]``.

    NOTE(review): the original author states this matches Cohere's
    preprocessing; that cannot be confirmed from this file alone.

    Args:
        audio: Waveform samples as accepted by librosa (time on the last axis).
        sr: Sample rate of ``audio`` in Hz.

    Returns:
        The normalized spectrogram array produced by librosa, with ``N_MELS``
        mel bands.
    """
    waveform = audio
    if sr != SAMPLE_RATE:
        waveform = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)

    power_spec = librosa.feature.melspectrogram(
        y=waveform,
        sr=SAMPLE_RATE,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        fmin=0,
        fmax=8000,
    )

    # dB relative to the loudest bin, then rescale and bound the result.
    log_spec = librosa.power_to_db(power_spec, ref=np.max)
    scaled = (log_spec + 80) / 80
    return np.clip(scaled, -1, 1)
|
|
|
|
def pad_mel(mel, target_frames=MAX_FRAMES):
    """Fit a mel spectrogram to exactly ``target_frames`` columns.

    Shorter inputs are right-padded with zeros; longer inputs are truncated.

    Args:
        mel: 2-D array of shape ``(n_mels, n_frames)``.
        target_frames: Number of frames in the returned array.

    Returns:
        A tuple ``(fitted, valid_frames)`` where ``fitted`` is a float32 array
        of shape ``(n_mels, target_frames)`` and ``valid_frames`` is the number
        of leading columns that hold real data. ``valid_frames`` never exceeds
        ``target_frames``, so it is safe to pass on as a feature length even
        when the input had to be truncated.
    """
    n_mels, n_frames = mel.shape

    # Clamp so the reported length always describes the returned array, not
    # the (possibly longer) input.
    valid_frames = min(n_frames, target_frames)

    # Single code path for both padding and truncation keeps the output dtype
    # and shape uniform.
    fitted = np.zeros((n_mels, target_frames), dtype=np.float32)
    fitted[:, :valid_frames] = mel[:, :valid_frames]

    return fitted, valid_frames
|
|
|
|
def create_attention_mask(seq_len):
    """Build the decoder self-attention mask for ``seq_len`` positions.

    Returns an all-zero float32 array of shape ``(1, 1, 1, seq_len)``.
    NOTE(review): presumably an additive mask where 0 means "attend" --
    confirm against the decoder model's input spec.
    """
    mask_shape = (1, 1, 1, seq_len)
    return np.zeros(mask_shape, dtype=np.float32)
|
|
|
|
def transcribe(audio_path, encoder, decoder, vocabulary):
    """Transcribe an audio file with the encoder and cache-external decoder.

    Pipeline: read audio -> mel spectrogram -> pad/truncate to ``MAX_FRAMES``
    -> encoder -> greedy (argmax) decoding for at most ``MAX_SEQ_LEN`` steps,
    with the key/value caches passed into and read back from the decoder on
    every step -> map token ids to text via ``vocabulary``.

    Args:
        audio_path: Path to an audio file readable by ``soundfile``.
        encoder: Model object exposing ``predict(dict) -> dict`` that yields a
            ``"hidden_states"`` output.
        decoder: Model object exposing ``predict(dict) -> dict`` that yields
            ``"logits"`` and ``"k_cache_{i}_out"`` / ``"v_cache_{i}_out"``.
        vocabulary: Sequence of token strings indexed by token id.

    Returns:
        The decoded transcription as a stripped string.

    Raises:
        ValueError: If the audio file contains no samples.
    """
    print(f"Transcribing: {audio_path}")

    audio, sr = sf.read(audio_path)
    if audio.size == 0:
        raise ValueError(f"Audio file contains no samples: {audio_path}")
    # soundfile returns (frames, channels) for multi-channel files, while the
    # mel pipeline expects a 1-D signal; down-mix to mono by averaging.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    duration = len(audio) / sr
    print(f" Duration: {duration:.2f}s")

    mel = compute_mel_spectrogram(audio, sr)
    total_frames = mel.shape[1]
    padded_mel, actual_frames = pad_mel(mel)
    print(f" Mel frames: {actual_frames} (padded to {MAX_FRAMES})")
    if total_frames > MAX_FRAMES:
        print(f" Warning: audio has {total_frames} mel frames; truncated to {MAX_FRAMES}")
    # The encoder input holds exactly MAX_FRAMES columns, so the reported
    # feature length must never exceed that.
    valid_frames = min(actual_frames, MAX_FRAMES)

    encoder_input = {
        "input_features": np.expand_dims(padded_mel, axis=0).astype(np.float32),
        "feature_length": np.array([valid_frames], dtype=np.int32),
    }
    encoder_output = encoder.predict(encoder_input)
    encoder_hidden = encoder_output["hidden_states"]
    print(f" Encoder output shape: {encoder_hidden.shape}")

    # Externally managed KV caches, one key/value pair per decoder layer.
    # NOTE(review): shape is presumably (batch, heads, seq, head_dim) -- confirm
    # against the decoder model's input spec.
    num_layers = 8
    cache_shape = (1, 8, MAX_SEQ_LEN, 128)
    k_caches = [np.zeros(cache_shape, dtype=np.float32) for _ in range(num_layers)]
    v_caches = [np.zeros(cache_shape, dtype=np.float32) for _ in range(num_layers)]

    # Loop-invariant decoder inputs, built once instead of on every step.
    encoder_hidden_f32 = encoder_hidden.astype(np.float32)
    encoder_seq_len = encoder_hidden.shape[1]
    cross_mask = np.ones((1, 1, 1, encoder_seq_len), dtype=np.float32)

    tokens = []
    current_token = START_TOKEN

    for step in range(MAX_SEQ_LEN):
        input_dict = {
            "input_id": np.array([[current_token]], dtype=np.int32),
            "position_id": np.array([[step]], dtype=np.int32),
            "encoder_hidden_states": encoder_hidden_f32,
            "cross_attention_mask": cross_mask,
            "attention_mask": create_attention_mask(step + 1),
        }
        for i in range(num_layers):
            input_dict[f"k_cache_{i}"] = k_caches[i]
            input_dict[f"v_cache_{i}"] = v_caches[i]

        output = decoder.predict(input_dict)

        # Greedy decoding: pick the highest-scoring token.
        logits = output["logits"]
        next_token = int(np.argmax(logits[0]))

        # Carry the updated caches into the next step.
        for i in range(num_layers):
            k_caches[i] = output[f"k_cache_{i}_out"]
            v_caches[i] = output[f"v_cache_{i}_out"]

        if next_token == EOS_TOKEN:
            print(f" EOS detected at step {step}")
            break

        tokens.append(next_token)
        current_token = next_token

    print(f" Generated {len(tokens)} tokens")

    # Drop low-numbered special ids, out-of-range ids and "<|...|>" control
    # tokens, keeping only text pieces.
    text_tokens = []
    for token_id in tokens:
        if token_id <= 4 or token_id == EOS_TOKEN or token_id >= len(vocabulary):
            continue
        token = vocabulary[token_id]
        if token.startswith("<|"):
            continue
        text_tokens.append(token)

    # The vocabulary marks word boundaries with "▁"; turn them into spaces.
    text = "".join(text_tokens).replace("▁", " ").strip()

    return text
|
|
|
|
def main():
    """Command-line entry point: load the models and tokenizer, transcribe, print.

    Parses the audio path plus optional ``--encoder``, ``--decoder`` and
    ``--tokenizer`` paths, loads both CoreML models and the SentencePiece
    tokenizer, runs ``transcribe`` and prints the resulting text.
    """
    cli = argparse.ArgumentParser(description="Transcribe audio with Cohere Cache-External")
    cli.add_argument("audio", help="Path to audio file (.wav, .flac, etc.)")
    cli.add_argument("--encoder", default="cohere_encoder.mlpackage", help="Path to encoder")
    cli.add_argument("--decoder", default="cohere_decoder_cache_external.mlpackage", help="Path to decoder")
    cli.add_argument("--tokenizer", default="tokenizer.model", help="Path to tokenizer")
    options = cli.parse_args()

    rule = "=" * 70

    print(rule)
    print("Cohere Transcribe - Cache-External Decoder")
    print(rule)
    print()

    print("Loading models...")
    encoder_model = ct.models.MLModel(options.encoder)
    decoder_model = ct.models.MLModel(options.decoder)
    print(" ✓ Models loaded")

    print("Loading tokenizer...")
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(options.tokenizer)
    # Materialize the id -> piece table once so decoding is a plain list lookup.
    piece_count = tokenizer.get_piece_size()
    vocabulary = list(map(tokenizer.id_to_piece, range(piece_count)))
    print(f" ✓ Loaded {len(vocabulary)} tokens")
    print()

    transcript = transcribe(options.audio, encoder_model, decoder_model, vocabulary)

    print()
    print(rule)
    print("TRANSCRIPTION")
    print(rule)
    print(transcript)
    print()
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|