|
|
""" |
|
|
True Streaming CoreML Diarization |
|
|
|
|
|
This script implements true streaming inference: |
|
|
Audio chunks → CoreML Preprocessor → Feature Buffer → CoreML Main Model → Predictions |
|
|
|
|
|
Audio is processed incrementally, features are accumulated with proper context handling. |
|
|
""" |
|
|
import os

# Work around the "duplicate OpenMP runtime" abort that can occur when torch
# and coremltools each load their own libomp (common on macOS). Must be set
# before importing torch / coremltools below.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import coremltools as ct |
|
|
import librosa |
|
|
import argparse |
|
|
import math |
|
|
|
|
|
|
|
|
from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Streaming configuration. First group mirrors the NeMo SortformerModules
# streaming attributes (lengths in post-subsampling frames); second group is
# the fixed input geometry of the exported CoreML main model; third group is
# the CoreML preprocessor's audio framing.
CONFIG = {
    # Diarization chunking (units: post-subsampling frames).
    'chunk_len': 4,
    'chunk_right_context': 1,
    'chunk_left_context': 2,
    'fifo_len': 63,
    'spkcache_len': 63,
    'spkcache_update_period': 50,
    # Pre-encoder subsampling: 8 feature frames -> 1 embedding frame.
    'subsampling_factor': 8,
    'sample_rate': 16000,  # Hz

    # Fixed CoreML main-model input sizes.
    # chunk_frames = (chunk_left_context + chunk_len + chunk_right_context)
    #              * subsampling_factor = (2 + 4 + 1) * 8 = 56 feature frames.
    'chunk_frames': 56,
    'spkcache_input_len': 63,
    'fifo_input_len': 63,

    # CoreML preprocessor framing: fixed 9200-sample audio input with a
    # 400-sample mel window and 160-sample hop; (9200 - 400) / 160 + 1 = 56
    # frames per call — presumably matching the export; TODO confirm.
    'preproc_audio_samples': 9200,
    'mel_window': 400,
    'mel_stride': 160,
}
|
|
|
|
|
|
|
|
def run_true_streaming(nemo_model, preproc_model, main_model, audio_path, config):
    """
    True streaming inference: audio chunks → preproc → main model.

    Strategy:
    1. Process audio in chunks through CoreML preprocessor
    2. Accumulate features
    3. When enough features for a diarization chunk (with context), run main model

    Args:
        nemo_model: NeMo SortformerEncLabelModel (used only for its
            ``sortformer_modules`` streaming-state utilities).
        preproc_model: CoreML preprocessor (audio → mel features).
        main_model: CoreML main diarization model.
        audio_path: Path to the input audio file.
        config: Dict of streaming/geometry constants (see CONFIG).

    Returns:
        torch.Tensor of speaker probabilities, shape (1, T, num_speakers),
        or None if no chunks were produced.
    """

    def _fit_time_len(tensor, target_len):
        # Zero-pad or trim the time axis (dim 1) of a (B, T, F) tensor to
        # exactly target_len — the CoreML model has fixed input shapes.
        cur = tensor.shape[1]
        if cur < target_len:
            return torch.nn.functional.pad(tensor, (0, 0, 0, target_len - cur))
        if cur > target_len:
            return tensor[:, :target_len, :]
        return tensor

    modules = nemo_model.sortformer_modules
    subsampling_factor = config['subsampling_factor']

    # ---- Load audio ----
    full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    total_samples = len(full_audio)
    print(f"Total audio samples: {total_samples}")

    mel_window = config['mel_window']
    mel_stride = config['mel_stride']
    preproc_len = config['preproc_audio_samples']

    # Hop the audio so consecutive preprocessor calls overlap by one mel
    # window; the duplicated feature frames are dropped below.
    audio_hop = preproc_len - mel_window

    all_features = []
    audio_offset = 0
    preproc_chunk_idx = 0

    # ---- Step 1: incremental feature extraction via CoreML preprocessor ----
    print("Step 1: Extracting features via CoreML preprocessor...")
    while audio_offset < total_samples:
        chunk_end = min(audio_offset + preproc_len, total_samples)
        audio_chunk = full_audio[audio_offset:chunk_end]
        actual_samples = len(audio_chunk)

        # Fixed-size CoreML input: zero-pad the tail chunk, pass true length.
        if actual_samples < preproc_len:
            audio_chunk = np.pad(audio_chunk, (0, preproc_len - actual_samples))

        preproc_out = preproc_model.predict({
            "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32),
            "length": np.array([actual_samples], dtype=np.int32),
        })
        feat_chunk = np.array(preproc_out["features"])
        n_valid_frames = int(preproc_out["feature_lengths"][0])

        if preproc_chunk_idx == 0:
            valid_feats = feat_chunk[:, :, :n_valid_frames]
        else:
            # Drop frames already produced by the previous (overlapping) call.
            # NOTE(review): overlap of 2 frames assumed correct for this
            # window/stride pair — validated empirically by this script.
            overlap_frames = (mel_window - mel_stride) // mel_stride + 1
            valid_feats = feat_chunk[:, :, overlap_frames:n_valid_frames]

        all_features.append(valid_feats)
        audio_offset += audio_hop
        preproc_chunk_idx += 1

        print(f"\r Processed audio chunk {preproc_chunk_idx}, features so far: {sum(f.shape[2] for f in all_features)}", end='')

    print()

    # ---- Assemble the full (B, F, T) feature tensor ----
    full_features = np.concatenate(all_features, axis=2)
    processed_signal = torch.from_numpy(full_features).float()
    print(f"Total features extracted: {processed_signal.shape}")

    # ---- Step 2: chunked diarization through the CoreML main model ----
    print("Step 2: Running diarization streaming...")
    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []

    total_feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat, chunk_idx = 0, 0, 0
    while end_feat < total_feat_len:
        # Context windows are clipped at the signal boundaries.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, total_feat_len)
        right_offset = min(right_ctx * subsampling_factor, total_feat_len - end_feat)

        chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset]
        actual_len = chunk_feat.shape[2]

        # (B, F, T) -> (B, T, F), then fit to the model's fixed chunk size.
        chunk_in = _fit_time_len(chunk_feat.transpose(1, 2), config['chunk_frames'])

        # Speaker cache / FIFO grow dynamically in the NeMo state; the CoreML
        # model needs them at fixed lengths, with true lengths passed alongside.
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]
        current_spkcache = _fit_time_len(state.spkcache, config['spkcache_input_len'])
        current_fifo = _fit_time_len(state.fifo, config['fifo_input_len'])

        coreml_out = main_model.predict({
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([actual_len], dtype=np.int32),
            "spkcache": current_spkcache.numpy().astype(np.float32),
            "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
            "fifo": current_fifo.numpy().astype(np.float32),
            "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32),
        })

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes expressed in post-subsampling frames for NeMo's update.
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)
        stt_feat = end_feat
        chunk_idx += 1

        print(f"\r Diarization chunk {chunk_idx}", end='')

    print()

    if len(all_preds) > 0:
        return torch.cat(all_preds, dim=1)
    return None
|
|
|
|
|
|
|
|
def run_reference(nemo_model, main_model, audio_path, config):
    """
    Reference implementation using NeMo preprocessing.

    Identical chunked inference to ``run_true_streaming`` but with features
    produced by NeMo's own preprocessor instead of the CoreML one, so the two
    paths can be compared end to end.

    Args:
        nemo_model: NeMo SortformerEncLabelModel (preprocessing + state utils).
        main_model: CoreML main diarization model.
        audio_path: Path to the input audio file.
        config: Dict of streaming/geometry constants (see CONFIG).

    Returns:
        torch.Tensor of speaker probabilities, shape (1, T, num_speakers),
        or None if no chunks were produced.
    """

    def _fit_time_len(tensor, target_len):
        # Zero-pad or trim the time axis (dim 1) of a (B, T, F) tensor to
        # exactly target_len — the CoreML model has fixed input shapes.
        cur = tensor.shape[1]
        if cur < target_len:
            return torch.nn.functional.pad(tensor, (0, 0, 0, target_len - cur))
        if cur > target_len:
            return tensor[:, :target_len, :]
        return tensor

    modules = nemo_model.sortformer_modules
    subsampling_factor = modules.subsampling_factor

    # ---- Load audio and extract features with NeMo's preprocessor ----
    full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()
    audio_length = torch.tensor([len(full_audio)], dtype=torch.long)

    with torch.no_grad():
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )
        # Trim any preprocessor padding beyond the valid feature length.
        processed_signal = processed_signal[:, :, :processed_signal_length.max()]

    print(f"NeMo Preproc: features shape = {processed_signal.shape}")

    # ---- Chunked diarization through the CoreML main model ----
    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []

    feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat, chunk_idx = 0, 0, 0
    while end_feat < feat_len:
        # Context windows are clipped at the signal boundaries.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset]
        actual_len = chunk_feat.shape[2]

        # (B, F, T) -> (B, T, F), then fit to the model's fixed chunk size.
        chunk_in = _fit_time_len(chunk_feat.transpose(1, 2), config['chunk_frames'])

        # Speaker cache / FIFO grow dynamically in the NeMo state; the CoreML
        # model needs them at fixed lengths, with true lengths passed alongside.
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]
        current_spkcache = _fit_time_len(state.spkcache, config['spkcache_input_len'])
        current_fifo = _fit_time_len(state.fifo, config['fifo_input_len'])

        coreml_out = main_model.predict({
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([actual_len], dtype=np.int32),
            "spkcache": current_spkcache.numpy().astype(np.float32),
            "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
            "fifo": current_fifo.numpy().astype(np.float32),
            "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32),
        })

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes expressed in post-subsampling frames for NeMo's update.
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)
        stt_feat = end_feat
        chunk_idx += 1

    if len(all_preds) > 0:
        return torch.cat(all_preds, dim=1)
    return None
|
|
|
|
|
|
|
|
def validate(model_name, coreml_dir, audio_path):
    """
    Validate true streaming against NeMo preprocessing.

    Loads the NeMo model (for streaming-state utilities and reference
    preprocessing) and both CoreML models, runs each pipeline on the given
    audio, and reports the element-wise error between their outputs.
    """

    def _banner(title, gap=False):
        # Section header framed by separator rules; optional leading newline.
        print(("\n" if gap else "") + "=" * 70)
        print(title)
        print("=" * 70)

    def _load_mlmodel(package_name):
        # CPU-only execution keeps CoreML outputs reproducible for comparison.
        return ct.models.MLModel(
            os.path.join(coreml_dir, package_name),
            compute_units=ct.ComputeUnit.CPU_ONLY
        )

    _banner("VALIDATION: True Streaming vs NeMo Preprocessing")

    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Force the model's streaming geometry to match the CoreML export.
    modules = nemo_model.sortformer_modules
    for attr in ('chunk_len', 'chunk_right_context', 'chunk_left_context',
                 'fifo_len', 'spkcache_len', 'spkcache_update_period'):
        setattr(modules, attr, CONFIG[attr])

    # Deterministic NeMo features: no dither, no length padding.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, "
          f"right_ctx={modules.chunk_right_context}")

    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = _load_mlmodel("SortformerPreprocessor.mlpackage")
    main_model = _load_mlmodel("Sortformer16.mlpackage")

    _banner("TEST 1: NeMo Preprocessing + CoreML Inference (Reference)", gap=True)

    ref_probs = run_reference(nemo_model, main_model, audio_path, CONFIG)
    if ref_probs is None:
        print("Reference inference failed!")
        return
    ref_probs_np = ref_probs.squeeze(0).detach().cpu().numpy()
    print(f"Reference Probs Shape: {ref_probs_np.shape}")

    _banner("TEST 2: True Streaming (Audio → CoreML Preproc → CoreML Main)", gap=True)

    streaming_probs = run_true_streaming(nemo_model, preproc_model, main_model, audio_path, CONFIG)

    if streaming_probs is None:
        print("True streaming inference produced no output!")
        return
    streaming_probs_np = streaming_probs.squeeze(0).detach().cpu().numpy()
    print(f"Streaming Probs Shape: {streaming_probs_np.shape}")

    # Compare over the common prefix; lengths can differ by a few frames.
    min_len = min(ref_probs_np.shape[0], streaming_probs_np.shape[0])
    diff = np.abs(ref_probs_np[:min_len] - streaming_probs_np[:min_len])
    print(f"\nLength: ref={ref_probs_np.shape[0]}, streaming={streaming_probs_np.shape[0]}")
    print(f"Mean Absolute Error: {np.mean(diff):.8f}")
    print(f"Max Absolute Error: {np.max(diff):.8f}")

    if np.max(diff) < 0.01:
        print("\n✅ SUCCESS: True streaming matches reference!")
    else:
        print("\n⚠️ Errors exceed tolerance")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: model name, CoreML package directory, and test audio.
    parser = argparse.ArgumentParser()
    for flag, default in (
        ("--model_name", "nvidia/diar_streaming_sortformer_4spk-v2.1"),
        ("--coreml_dir", "coreml_models"),
        ("--audio_path", "audio.wav"),
    ):
        parser.add_argument(flag, default=default)
    args = parser.parse_args()

    validate(args.model_name, args.coreml_dir, args.audio_path)
|
|
|