# diar-streaming-sortformer-coreml / streaming_preproc_inference.py
# Uploaded by alexwengg ("Upload 33 files", commit ed33fd7, verified)
"""
True Streaming CoreML Diarization
This script implements true streaming inference:
Audio chunks → CoreML Preprocessor → Feature Buffer → CoreML Main Model → Predictions
Audio is processed incrementally, features are accumulated with proper context handling.
"""
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import numpy as np
import coremltools as ct
import librosa
import argparse
import math
# Import NeMo for state management (streaming_update) only
from nemo.collections.asr.models import SortformerEncLabelModel
# ============================================================
# Configuration for Sortformer16.mlpackage
# ============================================================
# Fixed-shape settings shared by the CoreML preprocessor and the
# Sortformer16 main model. Values here must agree with the shapes the
# exported .mlpackage files were built with.
CONFIG = dict(
    # --- Streaming geometry (in diarization frames) ---
    chunk_len=4,                 # diarization chunk length
    chunk_right_context=1,       # right context chunks
    chunk_left_context=2,        # left context chunks
    fifo_len=63,
    spkcache_len=63,
    spkcache_update_period=50,
    subsampling_factor=8,
    sample_rate=16000,
    # --- Derived values ---
    chunk_frames=56,             # (4 + 2 + 1) * 8 feature frames per CoreML chunk input
    spkcache_input_len=63,
    fifo_input_len=63,
    # --- Preprocessor settings ---
    preproc_audio_samples=9200,  # fixed input size of the CoreML preprocessor
    mel_window=400,              # 25 ms at 16 kHz
    mel_stride=160,              # 10 ms at 16 kHz
)
def _fit_time_axis(tensor, target_len):
    """Zero-pad or truncate dimension 1 of ``tensor`` to exactly ``target_len``.

    Returns ``(fitted_tensor, valid_len)`` where ``valid_len`` is the number of
    entries along dimension 1 that come from the input. It is clamped to
    ``target_len`` so a length reported to the model never exceeds the tensor
    that is actually passed in.
    """
    current_len = tensor.shape[1]
    if current_len < target_len:
        # Pad spec: (last dim left, last dim right, dim 1 left, dim 1 right).
        tensor = torch.nn.functional.pad(tensor, (0, 0, 0, target_len - current_len))
    elif current_len > target_len:
        tensor = tensor[:, :target_len, :]
    return tensor, min(current_len, target_len)


def run_true_streaming(nemo_model, preproc_model, main_model, audio_path, config):
    """
    True streaming inference: audio chunks → preproc → main model.

    Strategy:
    1. Slice the audio into fixed-size windows and run each one through the
       CoreML preprocessor, dropping the frames that overlap the previous window.
    2. Concatenate the resulting features.
    3. Walk over the features chunk by chunk (with left/right context), run the
       CoreML main model, and advance the streaming state with
       ``nemo_model.sortformer_modules.streaming_update``.

    Args:
        nemo_model: Object exposing ``sortformer_modules``; only its streaming
            state helpers and chunk/context settings are used here.
        preproc_model: CoreML preprocessor; ``predict`` must return
            ``"features"`` and ``"feature_lengths"``.
        main_model: CoreML main model; ``predict`` must return
            ``"speaker_preds"``, ``"chunk_pre_encoder_embs"`` and
            ``"chunk_pre_encoder_lengths"``.
        audio_path: Path of the audio file to load.
        config: Mapping with the keys used below (see ``CONFIG``).

    Returns:
        The per-chunk outputs of ``streaming_update`` concatenated along
        dimension 1, or ``None`` when the audio is empty or no chunk was
        produced.

    Raises:
        ValueError: If ``preproc_audio_samples`` is not larger than
            ``mel_window`` (the window would never advance).
    """
    modules = nemo_model.sortformer_modules
    subsampling_factor = config['subsampling_factor']

    # Load full audio (simulating microphone input)
    full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    total_samples = len(full_audio)
    print(f"Total audio samples: {total_samples}")

    # Preprocessing parameters
    mel_window = config['mel_window']
    mel_stride = config['mel_stride']
    preproc_len = config['preproc_audio_samples']

    # Consecutive preprocessor windows overlap by one mel window of audio.
    audio_hop = preproc_len - mel_window
    if audio_hop <= 0:
        raise ValueError(
            "config['preproc_audio_samples'] must be larger than config['mel_window']"
        )
    # Leading frames of every window after the first that duplicate the tail
    # of the previous window; constant for the whole run.
    overlap_frames = (mel_window - mel_stride) // mel_stride + 1

    # Step 1: Process all audio through preprocessor to get features
    print("Step 1: Extracting features via CoreML preprocessor...")
    all_features = []
    total_frames = 0  # running total, so progress reporting stays O(1) per window
    for preproc_chunk_idx, audio_offset in enumerate(range(0, total_samples, audio_hop)):
        audio_chunk = full_audio[audio_offset:audio_offset + preproc_len]
        actual_samples = len(audio_chunk)
        # The preprocessor takes a fixed-size input; zero-pad the final window.
        if actual_samples < preproc_len:
            audio_chunk = np.pad(audio_chunk, (0, preproc_len - actual_samples))

        preproc_out = preproc_model.predict({
            "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32),
            "length": np.array([actual_samples], dtype=np.int32),
        })
        # Presumably [1, 128, frames] per the original author's note — TODO confirm.
        feat_chunk = np.array(preproc_out["features"])
        feat_len = int(preproc_out["feature_lengths"][0])

        # First window keeps every valid frame; later windows skip the overlap.
        first_frame = 0 if preproc_chunk_idx == 0 else overlap_frames
        valid_feats = feat_chunk[:, :, first_frame:feat_len]
        all_features.append(valid_feats)
        total_frames += valid_feats.shape[2]
        print(f"\r Processed audio chunk {preproc_chunk_idx + 1}, features so far: {total_frames}", end='')
    print()

    if not all_features:
        # Empty audio: nothing to concatenate, so report "no output" instead of
        # letting np.concatenate raise on an empty list.
        return None

    # Concatenate all features along the frame axis
    full_features = np.concatenate(all_features, axis=2)
    processed_signal = torch.from_numpy(full_features).float()
    print(f"Total features extracted: {processed_signal.shape}")

    # Step 2: Run diarization streaming loop (same as NeMo reference)
    print("Step 2: Running diarization streaming...")
    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []
    feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context
    chunk_frames = config['chunk_frames']
    spkcache_input_len = config['spkcache_input_len']
    fifo_input_len = config['fifo_input_len']

    stt_feat, end_feat, chunk_idx = 0, 0, 0
    while end_feat < feat_len:
        # Context is clipped at both ends of the recording.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        # Extract chunk with context and move the frame axis to dimension 1.
        chunk_feat = processed_signal[:, :, stt_feat - left_offset:end_feat + right_offset]
        chunk_in, chunk_valid = _fit_time_axis(chunk_feat.transpose(1, 2), chunk_frames)

        # The CoreML model has fixed-size state inputs; fit the current state to them.
        spkcache_in, spkcache_valid = _fit_time_axis(state.spkcache, spkcache_input_len)
        fifo_in, fifo_valid = _fit_time_axis(state.fifo, fifo_input_len)

        # CoreML inference
        coreml_out = main_model.predict({
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([chunk_valid], dtype=np.int32),
            "spkcache": spkcache_in.numpy().astype(np.float32),
            "spkcache_lengths": np.array([spkcache_valid], dtype=np.int32),
            "fifo": fifo_in.numpy().astype(np.float32),
            "fifo_lengths": np.array([fifo_valid], dtype=np.int32),
        })
        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])
        # Drop the embeddings that correspond to padding.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes converted from feature frames to subsampled frames.
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)
        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc,
        )
        all_preds.append(chunk_probs)

        stt_feat = end_feat
        chunk_idx += 1
        print(f"\r Diarization chunk {chunk_idx}", end='')
    print()

    if all_preds:
        return torch.cat(all_preds, dim=1)
    return None
def run_reference(nemo_model, main_model, audio_path, config):
    """
    Reference implementation using NeMo preprocessing.

    Features are extracted in one shot with ``nemo_model.process_signal`` and
    then fed chunk by chunk (with left/right context) to the CoreML main model,
    advancing the streaming state with
    ``nemo_model.sortformer_modules.streaming_update``.

    Args:
        nemo_model: Object exposing ``process_signal`` and ``sortformer_modules``.
        main_model: CoreML main model; ``predict`` must return
            ``"speaker_preds"``, ``"chunk_pre_encoder_embs"`` and
            ``"chunk_pre_encoder_lengths"``.
        audio_path: Path of the audio file to load.
        config: Mapping providing ``sample_rate``, ``chunk_frames``,
            ``spkcache_input_len`` and ``fifo_input_len``.

    Returns:
        The per-chunk outputs of ``streaming_update`` concatenated along
        dimension 1, or ``None`` if no chunk was produced.
    """
    def fit_frames(tensor, target_len):
        # Zero-pad or truncate dimension 1 to the model's fixed input size and
        # report how many entries are real, clamped so the reported length
        # never exceeds the tensor actually handed to CoreML.
        current_len = tensor.shape[1]
        if current_len < target_len:
            tensor = torch.nn.functional.pad(tensor, (0, 0, 0, target_len - current_len))
        elif current_len > target_len:
            tensor = tensor[:, :target_len, :]
        return tensor, min(current_len, target_len)

    modules = nemo_model.sortformer_modules
    subsampling_factor = modules.subsampling_factor

    # Load full audio
    full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()
    audio_length = torch.tensor([len(full_audio)], dtype=torch.long)

    # Extract features using NeMo preprocessor
    with torch.no_grad():
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )
    # Keep only the frames reported as valid by the preprocessor.
    processed_signal = processed_signal[:, :, :int(processed_signal_length.max())]
    print(f"NeMo Preproc: features shape = {processed_signal.shape}")

    # Streaming loop
    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []
    feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat = 0, 0
    while end_feat < feat_len:
        # Context is clipped at both ends of the recording.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        chunk_feat = processed_signal[:, :, stt_feat - left_offset:end_feat + right_offset]
        chunk_in, chunk_valid = fit_frames(chunk_feat.transpose(1, 2), config['chunk_frames'])
        spkcache_in, spkcache_valid = fit_frames(state.spkcache, config['spkcache_input_len'])
        fifo_in, fifo_valid = fit_frames(state.fifo, config['fifo_input_len'])

        coreml_out = main_model.predict({
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([chunk_valid], dtype=np.int32),
            "spkcache": spkcache_in.numpy().astype(np.float32),
            "spkcache_lengths": np.array([spkcache_valid], dtype=np.int32),
            "fifo": fifo_in.numpy().astype(np.float32),
            "fifo_lengths": np.array([fifo_valid], dtype=np.int32),
        })
        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])
        # Drop the embeddings that correspond to padding.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes converted from feature frames to subsampled frames.
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)
        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc,
        )
        all_preds.append(chunk_probs)
        stt_feat = end_feat

    if all_preds:
        return torch.cat(all_preds, dim=1)
    return None
def validate(model_name, coreml_dir, audio_path):
    """
    Validate true streaming against NeMo preprocessing.

    Loads the NeMo model and both CoreML packages, runs ``run_reference``
    (NeMo features + CoreML main model) and ``run_true_streaming`` (CoreML
    preprocessor + CoreML main model) on the same audio, and prints the mean
    and max absolute difference over the frames both outputs share. A max
    difference below 0.01 is reported as success.

    Args:
        model_name: Name passed to ``SortformerEncLabelModel.from_pretrained``.
        coreml_dir: Directory containing ``SortformerPreprocessor.mlpackage``
            and ``Sortformer16.mlpackage``.
        audio_path: Path of the audio file to validate on.

    Returns:
        None. All results are printed.
    """
    banner = "=" * 70
    print(banner)
    print("VALIDATION: True Streaming vs NeMo Preprocessing")
    print(banner)

    # Load NeMo model
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Apply the streaming settings from CONFIG to the NeMo streaming modules
    modules = nemo_model.sortformer_modules
    for key in (
        'chunk_len',
        'chunk_right_context',
        'chunk_left_context',
        'fifo_len',
        'spkcache_len',
        'spkcache_update_period',
    ):
        setattr(modules, key, CONFIG[key])

    # Disable dither and pad_to on the NeMo featurizer, when present
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0
    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, "
          f"right_ctx={modules.chunk_right_context}")

    # Load CoreML models
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPreprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Sortformer16.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )

    # Reference
    print("\n" + banner)
    print("TEST 1: NeMo Preprocessing + CoreML Inference (Reference)")
    print(banner)
    ref_probs = run_reference(nemo_model, main_model, audio_path, CONFIG)
    if ref_probs is None:
        print("Reference inference failed!")
        return
    ref_probs_np = ref_probs.squeeze(0).detach().cpu().numpy()
    print(f"Reference Probs Shape: {ref_probs_np.shape}")

    # True streaming
    print("\n" + banner)
    print("TEST 2: True Streaming (Audio → CoreML Preproc → CoreML Main)")
    print(banner)
    streaming_probs = run_true_streaming(nemo_model, preproc_model, main_model, audio_path, CONFIG)
    if streaming_probs is None:
        print("True streaming inference produced no output!")
        return
    streaming_probs_np = streaming_probs.squeeze(0).detach().cpu().numpy()
    print(f"Streaming Probs Shape: {streaming_probs_np.shape}")

    # Compare over the frames present in both outputs
    ref_len = ref_probs_np.shape[0]
    streaming_len = streaming_probs_np.shape[0]
    min_len = min(ref_len, streaming_len)
    print(f"\nLength: ref={ref_len}, streaming={streaming_len}")
    if min_len == 0:
        # np.max / np.mean on an empty array would raise or yield NaN; there
        # is nothing meaningful to compare in this case.
        print("\n⚠️ No overlapping frames to compare")
        return

    diff = np.abs(ref_probs_np[:min_len] - streaming_probs_np[:min_len])
    max_err = np.max(diff)
    print(f"Mean Absolute Error: {np.mean(diff):.8f}")
    print(f"Max Absolute Error: {max_err:.8f}")
    if max_err < 0.01:
        print("\n✅ SUCCESS: True streaming matches reference!")
    else:
        print("\n⚠️ Errors exceed tolerance")
# Script entry point: read the three optional command-line overrides and
# run the validation once.
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # (flag, default) pairs; every option is an optional string.
    for flag, fallback in (
        ("--model_name", "nvidia/diar_streaming_sortformer_4spk-v2.1"),
        ("--coreml_dir", "coreml_models"),
        ("--audio_path", "audio.wav"),
    ):
        cli.add_argument(flag, default=fallback)
    options = cli.parse_args()
    validate(
        model_name=options.model_name,
        coreml_dir=options.coreml_dir,
        audio_path=options.audio_path,
    )