|
|
|
|
|
"""Test CoreML inference for Nemotron Streaming 0.6B on LibriSpeech test-clean.""" |
|
|
import glob |
|
|
import json |
|
|
import re |
|
|
from pathlib import Path |
|
|
|
|
|
import coremltools as ct |
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
def load_ground_truth(librispeech_path: str) -> dict:
    """Collect LibriSpeech reference transcripts keyed by utterance id.

    Walks *librispeech_path* recursively for ``*.trans.txt`` files; each
    line is ``<utterance-id> <TRANSCRIPT>`` and the transcript is stored
    lowercased. Malformed lines (no space) are skipped.
    """
    transcripts: dict = {}
    for trans_path in Path(librispeech_path).rglob("*.trans.txt"):
        with open(trans_path) as handle:
            for raw_line in handle:
                fields = raw_line.strip().split(" ", 1)
                if len(fields) != 2:
                    continue
                utt_id, transcript = fields
                transcripts[utt_id] = transcript.lower()
    return transcripts
|
|
|
|
|
|
|
|
def normalize_text(text: str) -> str:
    """Normalize text for WER calculation.

    Strips punctuation (anything that is not a word character or
    whitespace), lowercases, and collapses runs of whitespace to a
    single space.
    """
    text = re.sub(r'[^\w\s]', '', text)
    return ' '.join(text.lower().split())


def compute_wer(reference: str, hypothesis: str) -> tuple:
    """Compute word-level edit distance between reference and hypothesis.

    Both strings are normalized (see normalize_text) before comparison.

    Returns:
        ``(errors, num_ref_words)`` as plain Python ints, where *errors*
        is the minimum number of word substitutions, insertions, and
        deletions (Levenshtein distance over words).
    """
    ref_words = normalize_text(reference).split()
    hyp_words = normalize_text(hypothesis).split()

    # Standard dynamic-programming edit distance over words.
    d = np.zeros((len(ref_words) + 1, len(hyp_words) + 1), dtype=np.uint32)
    # Base cases: transforming to/from the empty sequence.
    d[:, 0] = np.arange(len(ref_words) + 1)
    d[0, :] = np.arange(len(hyp_words) + 1)

    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if ref_words[i-1] == hyp_words[j-1]:
                d[i, j] = d[i-1, j-1]  # match: no cost
            else:
                # deletion, insertion, substitution (each cost 1)
                d[i, j] = min(d[i-1, j] + 1, d[i, j-1] + 1, d[i-1, j-1] + 1)

    # Cast from np.uint32 so callers receive a plain Python int rather
    # than a numpy scalar (the original leaked the numpy type).
    errors = int(d[len(ref_words), len(hyp_words)])
    return errors, len(ref_words)
|
|
|
|
|
|
|
|
class NemotronCoreMLInference:
    """CoreML inference for Nemotron Streaming.

    Drives chunked streaming RNN-T greedy decoding over four exported
    CoreML sub-models: preprocessor (audio -> mel), cache-aware encoder,
    LSTM decoder (prediction network), and joint network.
    """

    def __init__(self, model_dir: str):
        # model_dir is expected to contain metadata.json, tokenizer.json,
        # and the four .mlpackage bundles loaded below.
        model_dir = Path(model_dir)

        # Export-time constants: shapes, frame counts, special token ids.
        with open(model_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        # Maps token id (as a string key) to its subword piece.
        with open(model_dir / "tokenizer.json") as f:
            self.tokenizer = json.load(f)

        print("Loading CoreML models...")
        self.preprocessor = ct.models.MLModel(str(model_dir / "preprocessor.mlpackage"))
        self.encoder = ct.models.MLModel(str(model_dir / "encoder.mlpackage"))
        self.decoder = ct.models.MLModel(str(model_dir / "decoder.mlpackage"))
        self.joint = ct.models.MLModel(str(model_dir / "joint.mlpackage"))
        print("Models loaded!")

        # Frequently used metadata values cached as attributes.
        self.sample_rate = self.metadata["sample_rate"]
        # New mel frames consumed per streaming step.
        self.chunk_mel_frames = self.metadata["chunk_mel_frames"]
        # Mel frames of left context re-fed to the encoder each step.
        self.pre_encode_cache = self.metadata["pre_encode_cache"]
        # Fixed encoder input width; chunks are padded up to this size.
        self.total_mel_frames = self.metadata["total_mel_frames"]
        # RNN-T blank token index.
        self.blank_idx = self.metadata["blank_idx"]
        self.vocab_size = self.metadata["vocab_size"]
        # LSTM decoder dimensions (hidden size, number of layers).
        self.decoder_hidden = self.metadata["decoder_hidden"]
        self.decoder_layers = self.metadata["decoder_layers"]

        # Shapes of the encoder's streaming cache tensors (from export).
        self.cache_channel_shape = self.metadata["cache_channel_shape"]
        self.cache_time_shape = self.metadata["cache_time_shape"]

    def _get_initial_cache(self):
        """Get initial encoder cache state: zeroed (channel, time, length) tensors."""
        cache_channel = np.zeros(self.cache_channel_shape, dtype=np.float32)
        cache_time = np.zeros(self.cache_time_shape, dtype=np.float32)
        cache_len = np.array([0], dtype=np.int32)
        return cache_channel, cache_time, cache_len

    def _get_initial_decoder_state(self):
        """Get initial (zeroed) decoder LSTM state (h, c), each (layers, 1, hidden)."""
        h = np.zeros((self.decoder_layers, 1, self.decoder_hidden), dtype=np.float32)
        c = np.zeros((self.decoder_layers, 1, self.decoder_hidden), dtype=np.float32)
        return h, c

    def _decode_tokens(self, tokens: list) -> str:
        """Decode token IDs to text.

        Unknown ids fall back to "" via tokenizer.get; SentencePiece-style
        word-boundary markers are turned into spaces.
        """
        text_parts = []
        for tok in tokens:
            # Skip the blank token and any out-of-vocabulary ids.
            if tok < self.vocab_size and tok != self.blank_idx:
                text_parts.append(self.tokenizer.get(str(tok), ""))

        text = "".join(text_parts)
        text = text.replace("▁", " ").strip()
        return text

    def transcribe(self, audio: np.ndarray) -> str:
        """Transcribe audio using streaming CoreML inference.

        Runs the preprocessor once over the whole waveform, then feeds
        the mel spectrogram to the encoder chunk by chunk, greedily
        decoding RNN-T tokens for each encoder output frame.

        Args:
            audio: mono waveform samples — assumed to already be at
                self.sample_rate (TODO confirm callers resample).
        """

        # CoreML expects float32 with an explicit leading batch dimension.
        audio = audio.astype(np.float32)
        if audio.ndim == 1:
            audio = audio.reshape(1, -1)

        audio_len = np.array([audio.shape[1]], dtype=np.int32)

        # Mel features for the entire utterance in a single call.
        preproc_out = self.preprocessor.predict({
            "audio": audio,
            "audio_length": audio_len
        })
        mel = preproc_out["mel"]
        mel_len = preproc_out["mel_length"][0]  # NOTE(review): currently unused

        # Fresh streaming state per utterance.
        cache_channel, cache_time, cache_len = self._get_initial_cache()
        h, c = self._get_initial_decoder_state()

        # Greedy RNN-T state: the last emitted token conditions the decoder.
        last_token = self.blank_idx
        all_tokens = []

        chunk_start = 0
        mel_total_frames = mel.shape[2]

        while chunk_start < mel_total_frames:

            if chunk_start == 0:

                # First chunk: no mel history yet.
                chunk_end = min(self.chunk_mel_frames, mel_total_frames)
                chunk_mel = mel[:, :, :chunk_end]

                # Left-pad with zeros up to the fixed encoder width —
                # presumably filling the (empty) pre-encode cache region;
                # confirm against the export configuration.
                if chunk_mel.shape[2] < self.total_mel_frames:
                    pad_width = self.total_mel_frames - chunk_mel.shape[2]
                    chunk_mel = np.pad(chunk_mel, ((0,0), (0,0), (pad_width, 0)), mode='constant')
            else:

                # Later chunks: prepend pre_encode_cache frames of real
                # left context taken from the mel spectrogram.
                cache_start = max(0, chunk_start - self.pre_encode_cache)
                chunk_end = min(chunk_start + self.chunk_mel_frames, mel_total_frames)
                chunk_mel = mel[:, :, cache_start:chunk_end]

                # Right-pad the final (short) chunk to the fixed width.
                if chunk_mel.shape[2] < self.total_mel_frames:
                    pad_width = self.total_mel_frames - chunk_mel.shape[2]
                    chunk_mel = np.pad(chunk_mel, ((0,0), (0,0), (0, pad_width)), mode='constant')

            chunk_mel_len = np.array([chunk_mel.shape[2]], dtype=np.int32)

            # One encoder step; cache tensors carry state across chunks.
            enc_out = self.encoder.predict({
                "mel": chunk_mel.astype(np.float32),
                "mel_length": chunk_mel_len,
                "cache_channel": cache_channel,
                "cache_time": cache_time,
                "cache_len": cache_len
            })

            encoded = enc_out["encoded"]
            cache_channel = enc_out["cache_channel_out"]
            cache_time = enc_out["cache_time_out"]
            cache_len = enc_out["cache_len_out"]

            # Greedy RNN-T decode, one encoder output frame at a time.
            num_enc_frames = encoded.shape[2]
            for t in range(num_enc_frames):
                enc_step = encoded[:, :, t:t+1]

                # Emit at most 10 symbols per frame to bound the inner loop.
                for _ in range(10):

                    token_input = np.array([[last_token]], dtype=np.int32)
                    token_len = np.array([1], dtype=np.int32)

                    # Prediction-network step conditioned on the last token.
                    dec_out = self.decoder.predict({
                        "token": token_input,
                        "token_length": token_len,
                        "h_in": h,
                        "c_in": c
                    })

                    decoder_out = dec_out["decoder_out"]
                    h_new = dec_out["h_out"]
                    c_new = dec_out["c_out"]

                    # Joint network fuses the encoder frame with the
                    # decoder output to produce vocabulary logits.
                    joint_out = self.joint.predict({
                        "encoder": enc_step.astype(np.float32),
                        "decoder": decoder_out[:, :, :1].astype(np.float32)
                    })

                    logits = joint_out["logits"]
                    pred_token = int(np.argmax(logits[0, 0, 0, :]))

                    if pred_token == self.blank_idx:

                        # Blank: move to the next encoder frame without
                        # committing the new decoder state.
                        break
                    else:

                        # Non-blank: emit it and advance the decoder state.
                        all_tokens.append(pred_token)
                        last_token = pred_token
                        h = h_new
                        c = c_new

            chunk_start += self.chunk_mel_frames

        return self._decode_tokens(all_tokens)
|
|
|
|
|
|
|
|
def main():
    """Run the CoreML streaming model over LibriSpeech files and report WER."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", type=str, default="nemotron_coreml")
    parser.add_argument("--dataset", type=str, default="datasets/LibriSpeech/test-clean")
    parser.add_argument("--num-files", type=int, default=10)
    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("NEMOTRON COREML INFERENCE TEST")
    print(banner)

    # Reference transcripts, keyed by utterance id.
    print(f"\nLoading ground truth from {args.dataset}...")
    references = load_ground_truth(args.dataset)
    print(f"Loaded {len(references)} transcriptions")

    # Deterministic subset of the dataset's audio files.
    audio_files = sorted(glob.glob(f"{args.dataset}/**/*.flac", recursive=True))[:args.num_files]
    print(f"Testing on {len(audio_files)} files")

    print()
    engine = NemotronCoreMLInference(args.model_dir)

    print("\n[COREML STREAMING]")
    err_sum = 0
    word_sum = 0

    for index, audio_path in enumerate(audio_files, start=1):
        utt_id = Path(audio_path).stem
        print(f" [{index}/{len(audio_files)}] {utt_id}", end=" ", flush=True)

        samples, _sr = sf.read(audio_path, dtype="float32")
        hyp = engine.transcribe(samples)

        if utt_id not in references:
            print("-> (no ground truth)")
            continue

        errs, words = compute_wer(references[utt_id], hyp)
        err_sum += errs
        word_sum += words
        running_wer = 100 * err_sum / word_sum
        print(f"-> {errs} errs, WER so far: {running_wer:.2f}%")
        if errs > 0:
            print(f"      REF: {references[utt_id][:80]}...")
            print(f"      HYP: {hyp[:80]}...")

    final_wer = 100 * err_sum / word_sum if word_sum > 0 else 0

    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    print(f"Files tested: {len(audio_files)}")
    print(f"CoreML WER: {final_wer:.2f}%")
    print("PyTorch WER: ~1.88% (reference)")


if __name__ == "__main__":
    main()
|
|
|