#!/usr/bin/env python3
"""
Example usage of Cohere Transcribe Cache-External CoreML models.
Requirements:
pip install coremltools numpy librosa soundfile sentencepiece
"""
import argparse
from pathlib import Path
import numpy as np
import coremltools as ct
import soundfile as sf
import librosa
import sentencepiece as spm
# Cohere config: audio front-end parameters and decoding limits used by the
# functions in this script.
SAMPLE_RATE = 16000  # Target sample rate in Hz; audio at other rates is resampled to this.
N_MELS = 128  # Number of mel bands requested from librosa.
HOP_LENGTH = 160  # STFT hop in samples (10 ms at SAMPLE_RATE).
N_FFT = 400  # FFT size in samples (25 ms at SAMPLE_RATE).
MAX_FRAMES = 3500  # Frame count the mel spectrogram is padded/truncated to before encoding.
MAX_SEQ_LEN = 108  # Sequence axis of each K/V cache and upper bound on decode steps.
# Special tokens - CRITICAL: Use correct EOS token!
START_TOKEN = 4  # Token id fed to the decoder at step 0.
EOS_TOKEN = 3 # <|endoftext|> - verified from model.generation_config.eos_token_id
def compute_mel_spectrogram(audio, sr=SAMPLE_RATE):
    """Turn a waveform into a normalised log-mel spectrogram.

    Args:
        audio: Waveform samples as accepted by librosa.
        sr: Sample rate of ``audio`` in Hz. When it differs from
            ``SAMPLE_RATE`` the waveform is resampled first.

    Returns:
        The array produced by librosa's mel spectrogram, converted to dB
        relative to its own maximum, shifted/scaled by ``(x + 80) / 80``
        and clipped to ``[-1, 1]``.
    """
    waveform = audio
    if sr != SAMPLE_RATE:
        waveform = librosa.resample(waveform, orig_sr=sr, target_sr=SAMPLE_RATE)

    mel_options = dict(
        sr=SAMPLE_RATE,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        fmin=0,
        fmax=8000,
    )
    power_spec = librosa.feature.melspectrogram(y=waveform, **mel_options)

    # dB relative to the loudest bin, so values are <= 0 before rescaling.
    log_spec = librosa.power_to_db(power_spec, ref=np.max)
    rescaled = (log_spec + 80) / 80
    return np.clip(rescaled, -1, 1)
def pad_mel(mel, target_frames=MAX_FRAMES):
    """Pad or truncate a mel spectrogram to exactly ``target_frames`` frames.

    Args:
        mel: 2-D array laid out as ``(n_mels, n_frames)``.
        target_frames: Number of frames (columns) in the returned array.

    Returns:
        A tuple ``(fixed, valid_frames)``. ``fixed`` is a float32 array of
        shape ``(n_mels, target_frames)``: the input zero-padded on the
        right, or cut off after ``target_frames`` columns. ``valid_frames``
        is the number of leading columns that hold real data, i.e.
        ``min(n_frames, target_frames)``, so it never exceeds the width of
        ``fixed`` and is safe to use as a feature length.
    """
    n_mels, n_frames = mel.shape
    # Clamp so the reported length always matches what is actually present
    # in the returned array, even when the input had to be truncated.
    valid_frames = min(n_frames, target_frames)
    fixed = np.zeros((n_mels, target_frames), dtype=np.float32)
    fixed[:, :valid_frames] = mel[:, :valid_frames]
    return fixed, valid_frames
def create_attention_mask(seq_len):
    """Build the decoder self-attention mask for ``seq_len`` positions.

    Returns a float32 array of shape ``(1, 1, 1, seq_len)`` filled with
    zeros.
    NOTE(review): presumably an additive mask where 0 means "attend" - confirm
    against the exported decoder.
    """
    mask_shape = (1, 1, 1, seq_len)
    return np.zeros(mask_shape, dtype=np.float32)
def transcribe(audio_path, encoder, decoder, vocabulary):
    """Transcribe an audio file using the encoder and cache-external decoder.

    Args:
        audio_path: Path of the audio file, read with ``soundfile``.
        encoder: Model object whose ``predict`` takes ``input_features`` and
            ``feature_length`` and returns ``hidden_states``.
        decoder: Model object whose ``predict`` runs one decoding step and
            returns ``logits`` plus ``k_cache_{i}_out`` / ``v_cache_{i}_out``.
        vocabulary: Sequence mapping token id -> token piece string.

    Returns:
        The decoded transcription as a stripped string.
    """
    print(f"Transcribing: {audio_path}")

    # 1. Load audio
    audio, sr = sf.read(audio_path)
    if audio.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files, which
        # the mel front-end below cannot consume; downmix to mono.
        audio = audio.mean(axis=1)
    duration = len(audio) / sr
    print(f" Duration: {duration:.2f}s")

    # 2. Compute mel spectrogram
    mel = compute_mel_spectrogram(audio, sr)
    total_frames = mel.shape[1]
    padded_mel, actual_frames = pad_mel(mel)
    if total_frames > MAX_FRAMES:
        print(f" Warning: audio has {total_frames} mel frames; truncating to {MAX_FRAMES}")
    # The length handed to the encoder must not exceed the padded width.
    feature_length = min(actual_frames, MAX_FRAMES)
    print(f" Mel frames: {actual_frames} (padded to {MAX_FRAMES})")

    # 3. Encode
    encoder_input = {
        "input_features": np.expand_dims(padded_mel, axis=0).astype(np.float32),
        "feature_length": np.array([feature_length], dtype=np.int32),
    }
    encoder_output = encoder.predict(encoder_input)
    encoder_hidden = encoder_output["hidden_states"]
    print(f" Encoder output shape: {encoder_hidden.shape}")
    # Convert once; the encoder output is the same for every decode step.
    encoder_hidden_f32 = encoder_hidden.astype(np.float32)

    # 4. Initialize caches (8 layers x K/V)
    num_layers = 8
    # presumably (batch, heads, seq, head_dim) - confirm against the decoder spec
    cache_shape = (1, 8, MAX_SEQ_LEN, 128)
    k_caches = [np.zeros(cache_shape, dtype=np.float32) for _ in range(num_layers)]
    v_caches = [np.zeros(cache_shape, dtype=np.float32) for _ in range(num_layers)]

    # Cross-attention mask (all ones - attend to all encoder positions)
    encoder_seq_len = encoder_hidden.shape[1]
    cross_mask = np.ones((1, 1, 1, encoder_seq_len), dtype=np.float32)

    # 5. Decode with cache-external pattern
    tokens = []
    current_token = START_TOKEN
    for step in range(MAX_SEQ_LEN):
        # Build decoder input
        input_dict = {
            "input_id": np.array([[current_token]], dtype=np.int32),
            "position_id": np.array([[step]], dtype=np.int32),
            "encoder_hidden_states": encoder_hidden_f32,
            "cross_attention_mask": cross_mask,
            "attention_mask": create_attention_mask(step + 1),  # Grows each step!
        }
        # Add all K/V caches to input
        for i in range(num_layers):
            input_dict[f"k_cache_{i}"] = k_caches[i]
            input_dict[f"v_cache_{i}"] = v_caches[i]

        # Run decoder (single step)
        output = decoder.predict(input_dict)

        # Sample next token (greedy)
        logits = output["logits"]
        next_token = int(np.argmax(logits[0]))

        # Update caches with outputs from model
        for i in range(num_layers):
            k_caches[i] = output[f"k_cache_{i}_out"]
            v_caches[i] = output[f"v_cache_{i}_out"]

        # Check for EOS (end of sequence)
        if next_token == EOS_TOKEN:
            print(f" EOS detected at step {step}")
            break

        tokens.append(next_token)
        current_token = next_token

    print(f" Generated {len(tokens)} tokens")

    # 6. Detokenize
    return _detokenize(tokens, vocabulary)


def _detokenize(tokens, vocabulary):
    """Join the pieces for ``tokens`` into text, skipping special tokens."""
    pieces = []
    for token_id in tokens:
        # Ids 0-4 and EOS are treated as control tokens; ids beyond the
        # vocabulary cannot be looked up.
        if token_id <= 4 or token_id == EOS_TOKEN or token_id >= len(vocabulary):
            continue
        piece = vocabulary[token_id]
        if piece.startswith("<|"):
            continue
        pieces.append(piece)
    # The "▁" marker stands for a word boundary in the token pieces.
    return "".join(pieces).replace("▁", " ").strip()
def main():
    """Command-line entry point: load models and tokenizer, then transcribe.

    Parses the audio path and optional encoder/decoder/tokenizer paths,
    loads both CoreML models and the SentencePiece vocabulary, runs
    ``transcribe`` and prints the resulting text.
    """
    cli = argparse.ArgumentParser(description="Transcribe audio with Cohere Cache-External")
    cli.add_argument("audio", help="Path to audio file (.wav, .flac, etc.)")
    optional_paths = (
        ("--encoder", "cohere_encoder.mlpackage", "Path to encoder"),
        ("--decoder", "cohere_decoder_cache_external.mlpackage", "Path to decoder"),
        ("--tokenizer", "tokenizer.model", "Path to tokenizer"),
    )
    for flag, default_path, help_text in optional_paths:
        cli.add_argument(flag, default=default_path, help=help_text)
    opts = cli.parse_args()

    rule = "=" * 70
    print(rule)
    print("Cohere Transcribe - Cache-External Decoder")
    print(rule)
    print()

    # Load models
    print("Loading models...")
    encoder_model = ct.models.MLModel(opts.encoder)
    decoder_model = ct.models.MLModel(opts.decoder)
    print(" ✓ Models loaded")

    # Load vocabulary as a plain id -> piece list
    print("Loading tokenizer...")
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(opts.tokenizer)
    piece_count = tokenizer.get_piece_size()
    vocab = [tokenizer.id_to_piece(piece_id) for piece_id in range(piece_count)]
    print(f" ✓ Loaded {len(vocab)} tokens")
    print()

    # Transcribe
    transcript = transcribe(opts.audio, encoder_model, decoder_model, vocab)

    print()
    print(rule)
    print("TRANSCRIPTION")
    print(rule)
    print(transcript)
    print()


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|