| """ |
| True Streaming CoreML Diarization |
| |
| This script implements true streaming inference: |
| Audio chunks → CoreML Preprocessor → Feature Buffer → CoreML Main Model → Predictions |
| |
| Audio is processed incrementally, features are accumulated with proper context handling. |
| """ |
| import os |
| os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
|
|
| import torch |
| import numpy as np |
| import coremltools as ct |
| import librosa |
| import argparse |
| import math |
|
|
| |
| from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
|
|
| |
| |
| |
| CONFIG = { |
| 'chunk_len': 4, |
| 'chunk_right_context': 1, |
| 'chunk_left_context': 2, |
| 'fifo_len': 63, |
| 'spkcache_len': 63, |
| 'spkcache_update_period': 50, |
| 'subsampling_factor': 8, |
| 'sample_rate': 16000, |
| |
| |
| 'chunk_frames': 56, |
| 'spkcache_input_len': 63, |
| 'fifo_input_len': 63, |
| |
| |
| 'preproc_audio_samples': 9200, |
| 'mel_window': 400, |
| 'mel_stride': 160, |
| } |
|
|
|
|
| def run_true_streaming(nemo_model, preproc_model, main_model, audio_path, config): |
| """ |
| True streaming inference: audio chunks → preproc → main model. |
| |
| Strategy: |
| 1. Process audio in chunks through CoreML preprocessor |
| 2. Accumulate features |
| 3. When enough features for a diarization chunk (with context), run main model |
| """ |
| modules = nemo_model.sortformer_modules |
| subsampling_factor = config['subsampling_factor'] |
| |
| |
| full_audio, sr = librosa.load(audio_path, sr=config['sample_rate'], mono=True) |
| total_samples = len(full_audio) |
| |
| print(f"Total audio samples: {total_samples}") |
| |
| |
| mel_window = config['mel_window'] |
| mel_stride = config['mel_stride'] |
| preproc_len = config['preproc_audio_samples'] |
| |
| |
| audio_hop = preproc_len - mel_window |
| |
| |
| all_features = [] |
| audio_offset = 0 |
| preproc_chunk_idx = 0 |
| |
| |
| print("Step 1: Extracting features via CoreML preprocessor...") |
| while audio_offset < total_samples: |
| |
| chunk_end = min(audio_offset + preproc_len, total_samples) |
| audio_chunk = full_audio[audio_offset:chunk_end] |
| actual_samples = len(audio_chunk) |
| |
| |
| if actual_samples < preproc_len: |
| audio_chunk = np.pad(audio_chunk, (0, preproc_len - actual_samples)) |
| |
| |
| preproc_inputs = { |
| "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32), |
| "length": np.array([actual_samples], dtype=np.int32) |
| } |
| |
| preproc_out = preproc_model.predict(preproc_inputs) |
| feat_chunk = np.array(preproc_out["features"]) |
| feat_len = int(preproc_out["feature_lengths"][0]) |
| |
| |
| if preproc_chunk_idx == 0: |
| |
| valid_feats = feat_chunk[:, :, :feat_len] |
| else: |
| |
| overlap_frames = (mel_window - mel_stride) // mel_stride + 1 |
| valid_feats = feat_chunk[:, :, overlap_frames:feat_len] |
| |
| all_features.append(valid_feats) |
| |
| audio_offset += audio_hop |
| preproc_chunk_idx += 1 |
| |
| print(f"\r Processed audio chunk {preproc_chunk_idx}, features so far: {sum(f.shape[2] for f in all_features)}", end='') |
| |
| print() |
| |
| |
| full_features = np.concatenate(all_features, axis=2) |
| processed_signal = torch.from_numpy(full_features).float() |
| processed_signal_length = torch.tensor([full_features.shape[2]], dtype=torch.long) |
| |
| print(f"Total features extracted: {processed_signal.shape}") |
| |
| |
| print("Step 2: Running diarization streaming...") |
| |
| state = modules.init_streaming_state(batch_size=1, device='cpu') |
| all_preds = [] |
| |
| feat_len = processed_signal.shape[2] |
| chunk_len = modules.chunk_len |
| left_ctx = modules.chunk_left_context |
| right_ctx = modules.chunk_right_context |
| |
| stt_feat, end_feat, chunk_idx = 0, 0, 0 |
| |
| while end_feat < feat_len: |
| left_offset = min(left_ctx * subsampling_factor, stt_feat) |
| end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len) |
| right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat) |
| |
| |
| chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset] |
| actual_len = chunk_feat.shape[2] |
| |
| |
| chunk_t = chunk_feat.transpose(1, 2) |
| |
| |
| if actual_len < config['chunk_frames']: |
| pad_len = config['chunk_frames'] - actual_len |
| chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len)) |
| else: |
| chunk_in = chunk_t[:, :config['chunk_frames'], :] |
| |
| |
| curr_spk_len = state.spkcache.shape[1] |
| curr_fifo_len = state.fifo.shape[1] |
| |
| current_spkcache = state.spkcache |
| if curr_spk_len < config['spkcache_input_len']: |
| current_spkcache = torch.nn.functional.pad( |
| current_spkcache, (0, 0, 0, config['spkcache_input_len'] - curr_spk_len) |
| ) |
| elif curr_spk_len > config['spkcache_input_len']: |
| current_spkcache = current_spkcache[:, :config['spkcache_input_len'], :] |
| |
| current_fifo = state.fifo |
| if curr_fifo_len < config['fifo_input_len']: |
| current_fifo = torch.nn.functional.pad( |
| current_fifo, (0, 0, 0, config['fifo_input_len'] - curr_fifo_len) |
| ) |
| elif curr_fifo_len > config['fifo_input_len']: |
| current_fifo = current_fifo[:, :config['fifo_input_len'], :] |
| |
| |
| coreml_inputs = { |
| "chunk": chunk_in.numpy().astype(np.float32), |
| "chunk_lengths": np.array([actual_len], dtype=np.int32), |
| "spkcache": current_spkcache.numpy().astype(np.float32), |
| "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32), |
| "fifo": current_fifo.numpy().astype(np.float32), |
| "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32) |
| } |
| |
| coreml_out = main_model.predict(coreml_inputs) |
| |
| pred_logits = torch.from_numpy(coreml_out["speaker_preds"]) |
| chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"]) |
| chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0]) |
| |
| chunk_embs = chunk_embs[:, :chunk_emb_len, :] |
| |
| lc = round(left_offset / subsampling_factor) |
| rc = math.ceil(right_offset / subsampling_factor) |
| |
| state, chunk_probs = modules.streaming_update( |
| streaming_state=state, |
| chunk=chunk_embs, |
| preds=pred_logits, |
| lc=lc, |
| rc=rc |
| ) |
| |
| all_preds.append(chunk_probs) |
| stt_feat = end_feat |
| chunk_idx += 1 |
| |
| print(f"\r Diarization chunk {chunk_idx}", end='') |
| |
| print() |
| |
| if len(all_preds) > 0: |
| return torch.cat(all_preds, dim=1) |
| return None |
|
|
|
|
| def run_reference(nemo_model, main_model, audio_path, config): |
| """ |
| Reference implementation using NeMo preprocessing. |
| """ |
| modules = nemo_model.sortformer_modules |
| subsampling_factor = modules.subsampling_factor |
| |
| |
| full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True) |
| audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float() |
| audio_length = torch.tensor([len(full_audio)], dtype=torch.long) |
| |
| |
| with torch.no_grad(): |
| processed_signal, processed_signal_length = nemo_model.process_signal( |
| audio_signal=audio_tensor, audio_signal_length=audio_length |
| ) |
| processed_signal = processed_signal[:, :, :processed_signal_length.max()] |
| |
| print(f"NeMo Preproc: features shape = {processed_signal.shape}") |
| |
| |
| state = modules.init_streaming_state(batch_size=1, device='cpu') |
| all_preds = [] |
| |
| feat_len = processed_signal.shape[2] |
| chunk_len = modules.chunk_len |
| left_ctx = modules.chunk_left_context |
| right_ctx = modules.chunk_right_context |
| |
| stt_feat, end_feat, chunk_idx = 0, 0, 0 |
| |
| while end_feat < feat_len: |
| left_offset = min(left_ctx * subsampling_factor, stt_feat) |
| end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len) |
| right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat) |
| |
| chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset] |
| actual_len = chunk_feat.shape[2] |
| |
| chunk_t = chunk_feat.transpose(1, 2) |
| |
| if actual_len < config['chunk_frames']: |
| pad_len = config['chunk_frames'] - actual_len |
| chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len)) |
| else: |
| chunk_in = chunk_t[:, :config['chunk_frames'], :] |
| |
| curr_spk_len = state.spkcache.shape[1] |
| curr_fifo_len = state.fifo.shape[1] |
| |
| current_spkcache = state.spkcache |
| if curr_spk_len < config['spkcache_input_len']: |
| current_spkcache = torch.nn.functional.pad( |
| current_spkcache, (0, 0, 0, config['spkcache_input_len'] - curr_spk_len) |
| ) |
| elif curr_spk_len > config['spkcache_input_len']: |
| current_spkcache = current_spkcache[:, :config['spkcache_input_len'], :] |
| |
| current_fifo = state.fifo |
| if curr_fifo_len < config['fifo_input_len']: |
| current_fifo = torch.nn.functional.pad( |
| current_fifo, (0, 0, 0, config['fifo_input_len'] - curr_fifo_len) |
| ) |
| elif curr_fifo_len > config['fifo_input_len']: |
| current_fifo = current_fifo[:, :config['fifo_input_len'], :] |
| |
| coreml_inputs = { |
| "chunk": chunk_in.numpy().astype(np.float32), |
| "chunk_lengths": np.array([actual_len], dtype=np.int32), |
| "spkcache": current_spkcache.numpy().astype(np.float32), |
| "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32), |
| "fifo": current_fifo.numpy().astype(np.float32), |
| "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32) |
| } |
| |
| coreml_out = main_model.predict(coreml_inputs) |
| |
| pred_logits = torch.from_numpy(coreml_out["speaker_preds"]) |
| chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"]) |
| chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0]) |
| |
| chunk_embs = chunk_embs[:, :chunk_emb_len, :] |
| |
| lc = round(left_offset / subsampling_factor) |
| rc = math.ceil(right_offset / subsampling_factor) |
| |
| state, chunk_probs = modules.streaming_update( |
| streaming_state=state, |
| chunk=chunk_embs, |
| preds=pred_logits, |
| lc=lc, |
| rc=rc |
| ) |
| |
| all_preds.append(chunk_probs) |
| stt_feat = end_feat |
| chunk_idx += 1 |
| |
| if len(all_preds) > 0: |
| return torch.cat(all_preds, dim=1) |
| return None |
|
|
|
|
| def validate(model_name, coreml_dir, audio_path): |
| """ |
| Validate true streaming against NeMo preprocessing. |
| """ |
| print("=" * 70) |
| print("VALIDATION: True Streaming vs NeMo Preprocessing") |
| print("=" * 70) |
| |
| |
| print(f"\nLoading NeMo Model: {model_name}") |
| nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu") |
| nemo_model.eval() |
| |
| |
| modules = nemo_model.sortformer_modules |
| modules.chunk_len = CONFIG['chunk_len'] |
| modules.chunk_right_context = CONFIG['chunk_right_context'] |
| modules.chunk_left_context = CONFIG['chunk_left_context'] |
| modules.fifo_len = CONFIG['fifo_len'] |
| modules.spkcache_len = CONFIG['spkcache_len'] |
| modules.spkcache_update_period = CONFIG['spkcache_update_period'] |
| |
| |
| if hasattr(nemo_model.preprocessor, 'featurizer'): |
| nemo_model.preprocessor.featurizer.dither = 0.0 |
| nemo_model.preprocessor.featurizer.pad_to = 0 |
| |
| print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, " |
| f"right_ctx={modules.chunk_right_context}") |
| |
| |
| print(f"Loading CoreML Models from {coreml_dir}...") |
| preproc_model = ct.models.MLModel( |
| os.path.join(coreml_dir, "SortformerPreprocessor.mlpackage"), |
| compute_units=ct.ComputeUnit.CPU_ONLY |
| ) |
| main_model = ct.models.MLModel( |
| os.path.join(coreml_dir, "Sortformer16.mlpackage"), |
| compute_units=ct.ComputeUnit.CPU_ONLY |
| ) |
| |
| |
| print("\n" + "=" * 70) |
| print("TEST 1: NeMo Preprocessing + CoreML Inference (Reference)") |
| print("=" * 70) |
| |
| ref_probs = run_reference(nemo_model, main_model, audio_path, CONFIG) |
| if ref_probs is not None: |
| ref_probs_np = ref_probs.squeeze(0).detach().cpu().numpy() |
| print(f"Reference Probs Shape: {ref_probs_np.shape}") |
| else: |
| print("Reference inference failed!") |
| return |
| |
| |
| print("\n" + "=" * 70) |
| print("TEST 2: True Streaming (Audio → CoreML Preproc → CoreML Main)") |
| print("=" * 70) |
| |
| streaming_probs = run_true_streaming(nemo_model, preproc_model, main_model, audio_path, CONFIG) |
| |
| if streaming_probs is not None: |
| streaming_probs_np = streaming_probs.squeeze(0).detach().cpu().numpy() |
| print(f"Streaming Probs Shape: {streaming_probs_np.shape}") |
| |
| |
| min_len = min(ref_probs_np.shape[0], streaming_probs_np.shape[0]) |
| diff = np.abs(ref_probs_np[:min_len] - streaming_probs_np[:min_len]) |
| print(f"\nLength: ref={ref_probs_np.shape[0]}, streaming={streaming_probs_np.shape[0]}") |
| print(f"Mean Absolute Error: {np.mean(diff):.8f}") |
| print(f"Max Absolute Error: {np.max(diff):.8f}") |
| |
| if np.max(diff) < 0.01: |
| print("\n✅ SUCCESS: True streaming matches reference!") |
| else: |
| print("\n⚠️ Errors exceed tolerance") |
| else: |
| print("True streaming inference produced no output!") |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--model_name", default="nvidia/diar_streaming_sortformer_4spk-v2.1") |
| parser.add_argument("--coreml_dir", default="coreml_models") |
| parser.add_argument("--audio_path", default="audio.wav") |
| args = parser.parse_args() |
| |
| validate(args.model_name, args.coreml_dir, args.audio_path) |
|
|