Upload 33 files
Browse files- convert.py +82 -0
- convert_dynamic.py +302 -0
- coreml_wrappers.py +215 -0
- mic_inference.py +519 -0
- nemo_streaming_reference.py +153 -0
- streaming_inference.py +262 -0
- streaming_preproc_inference.py +411 -0
convert.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import matplotlib.patches as patches
|
| 4 |
+
import matplotlib
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
import numpy as np
|
| 7 |
+
import threading
|
| 8 |
+
import onnx2torch
|
| 9 |
+
import onnxscript
|
| 10 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 11 |
+
from pydub import AudioSegment
|
| 12 |
+
import coremltools as ct
|
| 13 |
+
from pydub.playback import play as play_audio
|
| 14 |
+
|
| 15 |
+
# --- 1. Setup & Config ---
# Prefer the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
audio_file = "audio.wav"

# Load Audio for Playback (pydub uses milliseconds)
print("Loading audio file for playback...")
full_audio = AudioSegment.from_wav(audio_file)

# --- 2. Load Model ---
# Download (or reuse the cached) streaming Sortformer diarization checkpoint
# from the HuggingFace hub and move it to the selected device.
model = SortformerEncLabelModel.from_pretrained(
    "nvidia/diar_streaming_sortformer_4spk-v2.1",
    map_location=device
)
model.eval()
model.to(device)

# Sanity print: the ONNX export below relies on these output names.
print(model.output_names)
|
| 32 |
+
|
| 33 |
+
def streaming_input_examples(self, batch_size=4):
    """Build random input tensor examples for exporting the streaming model.

    Args:
        self: the Sortformer model (reads ``cfg`` for the feature dim and
            ``device`` for tensor placement).
        batch_size: number of example sequences. Defaults to 4, matching the
            previously hard-coded value, so existing callers are unchanged.

    Returns:
        Tuple of (chunk, chunk_lengths, spkcache, spkcache_lengths,
        fifo, fifo_lengths), all on ``self.device``.
    """
    # Feature dim comes from the preprocessor config; 128 is the model default.
    feat_in = self.cfg.get("preprocessor", {}).get("features", 128)
    # Example state lengths covering full (188), empty (0) and partial caches;
    # cycled so any batch_size works while batch_size=4 reproduces the
    # original values exactly.
    spkcache_len_examples = [40, 188, 0, 68]
    fifo_len_examples = [50, 88, 0, 90]
    chunk = torch.rand([batch_size, 120, feat_in]).to(self.device)
    chunk_lengths = torch.tensor([120] * batch_size).to(self.device)
    spkcache = torch.randn([batch_size, 188, 512]).to(self.device)
    spkcache_lengths = torch.tensor(
        [spkcache_len_examples[i % 4] for i in range(batch_size)]
    ).to(self.device)
    fifo = torch.randn([batch_size, 188, 512]).to(self.device)
    fifo_lengths = torch.tensor(
        [fifo_len_examples[i % 4] for i in range(batch_size)]
    ).to(self.device)
    return chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Build random example tensors and export the model to ONNX, then re-import
# the ONNX graph as a torch module so coremltools can convert it.
inputs = streaming_input_examples(model)

export_out = model.export("streaming-sortformer.onnx", input_example=inputs)
scripted_model = onnx2torch.convert('streaming-sortformer.onnx')

# Fixed tensor dimensions for the CoreML conversion — these must match the
# example tensors produced by streaming_input_examples above.
BATCH_SIZE = 4
CHUNK_LEN = 120
FEAT_DIM = 128
CACHE_LEN = 188
EMBED_DIM = 512

ct_inputs = [
    ct.TensorType(name="chunk", shape=(BATCH_SIZE, CHUNK_LEN, FEAT_DIM)),
    ct.TensorType(name="chunk_lens", shape=(BATCH_SIZE,)),
    ct.TensorType(name="spkcache", shape=(BATCH_SIZE, CACHE_LEN, EMBED_DIM)),
    ct.TensorType(name="spkcache_lens", shape=(BATCH_SIZE,)),
    ct.TensorType(name="fifo", shape=(BATCH_SIZE, CACHE_LEN, EMBED_DIM)),
    ct.TensorType(name="fifo_lens", shape=(BATCH_SIZE,)),
]

# Output shapes are left unspecified so coremltools infers them from the graph.
ct_outputs = [
    ct.TensorType(name="preds"),
    ct.TensorType(name="new_spkcache"),
    ct.TensorType(name="new_spkcache_lens"),
    ct.TensorType(name="new_fifo"),
    ct.TensorType(name="new_fifo_lens"),
]

# Convert to an ML Program targeting iOS 17 with FP16 compute precision.
# NOTE(review): the returned MLModel is not saved here — confirm whether a
# .save(...) call was intended.
ct.convert(
    scripted_model,
    inputs=ct_inputs,
    outputs=ct_outputs,
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT16,
)
|
convert_dynamic.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Convert Sortformer to CoreML with proper dynamic length handling.
|
| 4 |
+
|
| 5 |
+
The key issue: Original conversion traced with fixed lengths (spkcache=120, fifo=40),
|
| 6 |
+
but at runtime we need to handle empty state (spkcache=0, fifo=0) for first chunk.
|
| 7 |
+
|
| 8 |
+
Solution: Use scripting instead of tracing, or trace with multiple example lengths.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import coremltools as ct
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
+
sys.path.insert(0, os.path.join(SCRIPT_DIR, 'NeMo'))
|
| 20 |
+
|
| 21 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 22 |
+
|
| 23 |
+
print("=" * 70)
print("CONVERTING SORTFORMER WITH DYNAMIC LENGTH SUPPORT")
print("=" * 70)

# Load model
# Restored on CPU; strict=False tolerates checkpoint/config key mismatches.
model_path = os.path.join(SCRIPT_DIR, 'diar_streaming_sortformer_4spk-v2.nemo')
print(f"Loading model: {model_path}")
model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu', strict=False)
model.eval()

# Configure for low-latency streaming
# These overrides shrink the per-step chunk and bound the streaming state.
modules = model.sortformer_modules
modules.chunk_len = 6
modules.chunk_left_context = 1
modules.chunk_right_context = 1
modules.fifo_len = 40
modules.spkcache_len = 120
modules.spkcache_update_period = 30

print(f"Config: chunk_len={modules.chunk_len}, left={modules.chunk_left_context}, right={modules.chunk_right_context}")
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

# Dimensions
# Raw feature frames fed per step: (left + chunk + right) windows, each
# subsampling_factor frames wide.
chunk_frames = (modules.chunk_len + modules.chunk_left_context + modules.chunk_right_context) * modules.subsampling_factor
fc_d_model = modules.fc_d_model # 512
feat_dim = 128

print(f"Chunk frames: {chunk_frames}")
|
| 51 |
+
|
| 52 |
+
class DynamicPreEncoderWrapper(nn.Module):
    """Pre-encoder that packs variable-length state into a fixed buffer.

    Runs the model's pre-encoder over the incoming chunk, then packs the
    valid frames of speaker cache, FIFO buffer and chunk embeddings
    back-to-back at the start of a fixed-size, zero-padded tensor.
    """

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        self.model = model
        self.max_spkcache = max_spkcache
        self.max_fifo = max_fifo
        self.max_chunk = max_chunk
        # Fixed output capacity: worst case, all three segments are full.
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Pre-encode the incoming chunk (conformer subsampling front-end).
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        # Resolve per-segment valid lengths as python ints (0 for empty state).
        n_spk = spkcache_lengths[0].item() if spkcache_lengths.numel() > 0 else 0
        n_fifo = fifo_lengths[0].item() if fifo_lengths.numel() > 0 else 0
        n_chunk = chunk_emb_lengths[0].item()

        batch, _, dim = spkcache.shape
        packed = torch.zeros(batch, self.max_total, dim, device=chunk.device, dtype=chunk.dtype)

        # Pack segments contiguously: [spkcache | fifo | chunk | zeros].
        cursor = 0
        for source, n_valid in ((spkcache, n_spk), (fifo, n_fifo), (chunk_embs, n_chunk)):
            if n_valid > 0:
                packed[:, cursor:cursor + n_valid, :] = source[:, :n_valid, :]
                cursor += n_valid

        # NOTE(review): created on the default device (CPU), matching the
        # original behavior — confirm this is intended for MPS runs.
        total_length = torch.tensor([cursor], dtype=torch.long)

        return packed, total_length, chunk_embs, chunk_emb_lengths
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class DynamicHeadWrapper(nn.Module):
    """Conformer encoder + diarization head with length-aware output masking."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs, chunk_emb_lengths):
        # Encode the packed pre-encoded sequence; the model's own pre-encode
        # stage is bypassed because inputs are already embeddings.
        enc_embs, enc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker activity predictions: [B, T, num_speakers].
        preds = self.model.forward_infer(enc_embs, enc_lengths)

        # Zero out frames past the valid length so padding never leaks
        # into downstream post-processing.
        valid = pre_encoder_lengths[0]
        frame_idx = torch.arange(preds.shape[1], device=preds.device)
        keep = (frame_idx < valid).float().unsqueeze(0).unsqueeze(-1)
        preds = preds * keep

        return preds, chunk_embs, chunk_emb_lengths
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Test with both empty and full state
print("\n" + "=" * 70)
print("TESTING DYNAMIC WRAPPERS")
print("=" * 70)

pre_encoder = DynamicPreEncoderWrapper(model)
head = DynamicHeadWrapper(model)
pre_encoder.eval()
head.eval()

# Test 1: Empty state (like chunk 0)
# Exercises the zero-length spkcache/fifo path that tracing mishandles.
print("\nTest 1: Empty state (chunk 0)")
chunk = torch.randn(1, 56, 128) # First chunk has fewer frames
chunk_len = torch.tensor([56], dtype=torch.long)
spkcache = torch.zeros(1, 120, 512)
spkcache_len = torch.tensor([0], dtype=torch.long)
fifo = torch.zeros(1, 40, 512)
fifo_len = torch.tensor([0], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")
sums = [f"{preds[0, i, :].sum().item():.4f}" for i in range(min(8, preds.shape[1]))]
print(f" First 8 pred frames sum: {sums}")

# Test 2: Full state
# Same pipeline with the caches at maximum occupancy.
print("\nTest 2: Full state")
chunk = torch.randn(1, 64, 128)
chunk_len = torch.tensor([64], dtype=torch.long)
spkcache = torch.randn(1, 120, 512)
spkcache_len = torch.tensor([120], dtype=torch.long)
fifo = torch.randn(1, 40, 512)
fifo_len = torch.tensor([40], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")

print("\n" + "=" * 70)
print("ISSUE IDENTIFIED")
print("=" * 70)
print("""
The problem is that the current CoreML model was traced with FIXED lengths.
When lengths change at runtime, the traced operations don't adapt.

The fix requires re-tracing with proper dynamic handling OR using coremltools
flexible shapes feature.

For now, let's try a simpler approach: always pad inputs to max size and
use the length parameters only for extracting the correct output slice.
""")

# The issue is that torch.jit.trace captures specific tensor values
# We need to use torch.jit.script for truly dynamic behavior
# But many NeMo operations don't work with script

print("\nATTEMPTING CONVERSION WITH FLEXIBLE SHAPES...")
|
| 186 |
+
|
| 187 |
+
# Try using coremltools range shapes
# Conversion failures are caught below and reported with a traceback.
try:
    # Create wrapper that handles the length masking internally
|
| 190 |
+
class SimplePipelineWrapper(nn.Module):
    """Full streaming step: pre-encode -> pack state -> encode -> predict.

    The packed buffer always has fixed size 168 (= spkcache 120 + fifo 40 +
    chunk 8, matching NeMo's internal layout); the length tensors mark the
    valid region.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Pre-encode chunk (conformer subsampling front-end).
        chunk_embs, chunk_emb_lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        # Keep valid lengths as 0-d tensors so the masking ops below stay
        # inside the graph under torch.jit.trace.
        spk_len = spkcache_lengths[0]
        fifo_len = fifo_lengths[0]
        chunk_len = chunk_emb_lens[0]

        B = chunk.shape[0]
        max_out = 168  # 120 + 40 + 8
        D = 512

        concat_embs = torch.zeros(B, max_out, D, device=chunk.device, dtype=chunk.dtype)

        # Vectorized masked copies instead of per-frame python loops.
        # The original `for i in range(...): if i < length` pattern was both
        # O(T) python overhead and trace-unsafe: torch.jit.trace evaluates
        # each `if` eagerly and freezes the branch taken for the example
        # input's lengths, baking those lengths into the exported graph.
        # A boolean mask keeps the length genuinely dynamic.
        spk_mask = (torch.arange(120, device=chunk.device) < spk_len).to(chunk.dtype)
        concat_embs[:, :120, :] = spkcache[:, :120, :] * spk_mask.view(1, 120, 1)

        fifo_mask = (torch.arange(40, device=chunk.device) < fifo_len).to(chunk.dtype)
        concat_embs[:, 120:160, :] = fifo[:, :40, :] * fifo_mask.view(1, 40, 1)

        # Trace-time constant: pre_encode of a full chunk yields up to 8 frames.
        n_chunk = min(chunk_embs.shape[1], 8)
        chunk_mask = (torch.arange(n_chunk, device=chunk.device) < chunk_len).to(chunk.dtype)
        concat_embs[:, 160:160 + n_chunk, :] = chunk_embs[:, :n_chunk, :] * chunk_mask.view(1, n_chunk, 1)

        total_len = spk_len + fifo_len + chunk_len
        total_lens = total_len.unsqueeze(0)

        # Run through encoder (pre-encode bypassed: inputs are embeddings).
        fc_embs, fc_lens = self.model.frontend_encoder(
            processed_signal=concat_embs,
            processed_signal_length=total_lens,
            bypass_pre_encode=True,
        )

        # Get per-frame speaker predictions.
        preds = self.model.forward_infer(fc_embs, fc_lens)

        return preds, chunk_embs, chunk_emb_lens
|
| 241 |
+
|
| 242 |
+
wrapper = SimplePipelineWrapper(model)
wrapper.eval()

# Trace with empty state example
# (chunk 0 scenario: both caches declared empty via zero lengths).
print("Tracing with empty state example...")
chunk = torch.randn(1, 64, 128)
chunk_len = torch.tensor([56], dtype=torch.long) # Actual length
spkcache = torch.zeros(1, 120, 512)
spkcache_len = torch.tensor([0], dtype=torch.long)
fifo = torch.zeros(1, 40, 512)
fifo_len = torch.tensor([0], dtype=torch.long)

with torch.no_grad():
    traced = torch.jit.trace(wrapper, (chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len))

# Convert the traced module; shapes below must match the example tensors.
print("Converting to CoreML...")
mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(name="chunk", shape=(1, 64, 128), dtype=np.float32),
        ct.TensorType(name="chunk_lengths", shape=(1,), dtype=np.int32),
        ct.TensorType(name="spkcache", shape=(1, 120, 512), dtype=np.float32),
        ct.TensorType(name="spkcache_lengths", shape=(1,), dtype=np.int32),
        ct.TensorType(name="fifo", shape=(1, 40, 512), dtype=np.float32),
        ct.TensorType(name="fifo_lengths", shape=(1,), dtype=np.int32),
    ],
    outputs=[
        ct.TensorType(name="speaker_preds", dtype=np.float32),
        ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
        ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
    ],
    minimum_deployment_target=ct.target.iOS16,
    compute_precision=ct.precision.FLOAT32,
    compute_units=ct.ComputeUnit.CPU_ONLY, # Start with CPU for debugging
)

output_path = os.path.join(SCRIPT_DIR, 'coreml_models', 'SortformerPipeline_Dynamic.mlpackage')
mlmodel.save(output_path)
print(f"Saved to: {output_path}")

# Test the new model
# Round-trip the same example tensors through the converted model.
print("\nTesting new CoreML model...")
test_output = mlmodel.predict({
    'chunk': chunk.numpy(),
    'chunk_lengths': chunk_len.numpy().astype(np.int32),
    'spkcache': spkcache.numpy(),
    'spkcache_lengths': spkcache_len.numpy().astype(np.int32),
    'fifo': fifo.numpy(),
    'fifo_lengths': fifo_len.numpy().astype(np.int32),
})

coreml_preds = np.array(test_output['speaker_preds'])
print(f"CoreML predictions shape: {coreml_preds.shape}")
print(f"CoreML first 8 frames:")
for i in range(min(8, coreml_preds.shape[1])):
    print(f" Frame {i}: {coreml_preds[0, i, :]}")

except Exception as e:
    # Conversion/tracing can fail for many NeMo ops; report and continue.
    print(f"Error during conversion: {e}")
    import traceback
    traceback.print_exc()
|
coreml_wrappers.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from safe_concat import *
|
| 4 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def fixed_concat_and_pad(embs, lengths, max_total_len=188 + 188 + 6):
    """Pack the valid frames of three variable-length segments into one
    fixed-size, zero-padded tensor without any zero-length slicing.

    A single ``gather`` with arithmetically computed indices relocates the
    valid prefix of each segment, which keeps the graph ANE/CoreML friendly.

    Args:
        embs: [spkcache, fifo, chunk] tensors, each shaped (B, seq_len, D).
        lengths: matching per-segment valid lengths ((1,) or scalar tensors);
            the first two may be 0, the chunk length is always positive.
        max_total_len: fixed output sequence length (zero padded).

    Returns:
        (output, total_length): output is (B, max_total_len, D) with all
        valid frames packed at the front; total_length is the sum of the
        three valid lengths.
    """
    batch, _, dim = embs[0].shape
    dev = embs[0].device

    # Static per-segment capacities — known at trace time, so they become
    # constants in the exported graph.
    cap0, cap1, cap2 = embs[0].shape[1], embs[1].shape[1], embs[2].shape[1]
    capacity = cap0 + cap1 + cap2

    # One full-size concat: never slices by a (possibly zero) runtime length.
    stacked = torch.cat(embs, dim=1)  # (B, capacity, D)

    # Runtime valid lengths as 0-d tensors for cheap broadcasting.
    n0 = lengths[0].reshape(())
    n1 = lengths[1].reshape(())
    n2 = lengths[2].reshape(())
    total_length = n0 + n1 + n2

    positions = torch.arange(max_total_len, device=dev, dtype=torch.long)

    # Output position p maps to source index p + shift, where the shift
    # skips the invalid tail of every segment already passed:
    #   p in segment 0:  shift = 0
    #   p in segment 1:  shift = cap0 - n0
    #   p in segment 2:  shift = (cap0 - n0) + (cap1 - n1)
    past_seg0 = (positions >= n0).long()
    past_seg1 = (positions >= n0 + n1).long()
    shift = past_seg0 * (cap0 - n0) + past_seg1 * (cap1 - n1)
    src_idx = (positions + shift).clamp(0, capacity - 1)

    # Broadcast the index map over batch and feature dims for gather.
    src_idx = src_idx.unsqueeze(0).unsqueeze(-1).expand(batch, max_total_len, dim)
    output = torch.gather(stacked, dim=1, index=src_idx)

    # Zero everything past the packed valid region (clamped tail frames).
    output = output * (positions < total_length).float().unsqueeze(0).unsqueeze(-1)

    return output, total_length
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class PreprocessorWrapper(nn.Module):
    """Thin export shim around the NeMo preprocessor (FilterbankFeaturesTA).

    Gives CoreML export a plain (audio, length) -> (features, length)
    calling convention; features come back shaped [B, D, T].
    """

    def __init__(self, preprocessor):
        super().__init__()
        self.preprocessor = preprocessor

    def forward(self, audio_signal, length):
        # Delegate directly — NeMo already returns (features, length).
        return self.preprocessor(input_signal=audio_signal, length=length)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class SortformerHeadWrapper(nn.Module):
    """Conformer encoder + diarization head; passes chunk embeddings through."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_pre_encoder_embs, chunk_pre_encoder_lengths):
        # Encode the packed [spkcache|fifo|chunk] sequence; the pre-encode
        # stage is bypassed because the inputs are already embeddings.
        enc_embs, enc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker activity predictions.
        preds = self.model.forward_infer(enc_embs, enc_lengths)

        # Chunk embeddings are returned untouched so the caller can update
        # its streaming state outside the model.
        return preds, chunk_pre_encoder_embs, chunk_pre_encoder_lengths
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class SortformerCoreMLWrapper(nn.Module):
    """End-to-end export wrapper: pre-encode + pack, encode, predict.

    Composes PreEncoderWrapper with the model's frontend encoder and
    inference head so one streaming step traces as a single graph.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.pre_encoder = PreEncoderWrapper(model)

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Pre-encode the chunk and pack it together with the cached state.
        (packed_embs, packed_lengths,
         chunk_pre_encode_embs, chunk_pre_encode_lengths) = self.pre_encoder(
            chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths
        )

        # Encode the packed sequence (pre-encode stage bypassed: inputs are
        # already embeddings).
        enc_embs, enc_lengths = self.model.frontend_encoder(
            processed_signal=packed_embs,
            processed_signal_length=packed_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker activity predictions.
        preds = self.model.forward_infer(enc_embs, enc_lengths)

        return preds, chunk_pre_encode_embs, chunk_pre_encode_lengths
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class PreEncoderWrapper(nn.Module):
    """Chunk pre-encoder with optional state packing.

    Called with 6 arguments it pre-encodes the chunk and packs it behind
    the speaker-cache and FIFO state; called with 2 it only pre-encodes.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        mods = model.sortformer_modules
        window = mods.chunk_left_context + mods.chunk_len + mods.chunk_right_context
        # Fixed packed-sequence capacity: cache + fifo + (context-padded) chunk.
        self.pre_encoder_length = mods.spkcache_len + mods.fifo_len + window

    def forward(self, *args):
        # Dispatch on arity so one exported entry point serves both modes.
        return self.forward_concat(*args) if len(args) == 6 else self.forward_pre_encode(*args)

    def forward_concat(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Reuse the single-mode path, then pack behind the cached state.
        chunk_pre_encode_embs, chunk_pre_encode_lengths = self.forward_pre_encode(chunk, chunk_lengths)
        packed_embs, packed_lengths = fixed_concat_and_pad(
            [spkcache, fifo, chunk_pre_encode_embs],
            [spkcache_lengths, fifo_lengths, chunk_pre_encode_lengths],
            self.pre_encoder_length
        )
        return (packed_embs, packed_lengths,
                chunk_pre_encode_embs, chunk_pre_encode_lengths)

    def forward_pre_encode(self, chunk, chunk_lengths):
        embs, emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)
        # Ensure int64 lengths — pre_encode may return a narrower dtype.
        return embs, emb_lengths.to(torch.int64)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class ConformerEncoderWrapper(nn.Module):
    """Runs the conformer frontend over already pre-encoded embeddings."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encode_embs, pre_encode_lengths):
        # Inputs are embeddings, so the model's own pre-encode stage is skipped.
        enc_embs, enc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encode_embs,
            processed_signal_length=pre_encode_lengths,
            bypass_pre_encode=True,
        )
        return enc_embs, enc_lengths
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class SortformerEncoderWrapper(nn.Module):
    """Exposes only the diarization head (``forward_infer``) for export."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, encoder_embs, encoder_lengths):
        # Map encoder embeddings to per-frame speaker activity predictions.
        return self.model.forward_infer(encoder_embs, encoder_lengths)
|
mic_inference.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Real-Time Microphone Diarization with CoreML
|
| 3 |
+
|
| 4 |
+
This script captures audio from the microphone in real-time,
|
| 5 |
+
processes it through CoreML models, and displays a live updating
|
| 6 |
+
diarization heatmap.
|
| 7 |
+
|
| 8 |
+
Pipeline: Microphone → Audio Buffer → CoreML Preproc → CoreML Main → Live Plot
|
| 9 |
+
|
| 10 |
+
Requirements:
|
| 11 |
+
pip install pyaudio matplotlib seaborn numpy coremltools
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python mic_inference.py
|
| 15 |
+
"""
|
| 16 |
+
import os
|
| 17 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import coremltools as ct
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import matplotlib
|
| 24 |
+
matplotlib.use('TkAgg')
|
| 25 |
+
import seaborn as sns
|
| 26 |
+
import threading
|
| 27 |
+
import queue
|
| 28 |
+
import time
|
| 29 |
+
import math
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
# Import NeMo for state management
|
| 33 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 34 |
+
|
| 35 |
+
# Optional microphone backend: degrade gracefully when sounddevice is missing
# so that non-mic code paths can still run.
try:
    import sounddevice as sd
    SOUNDDEVICE_AVAILABLE = True
except ImportError:
    # BUG FIX: this branch previously re-ran `import sounddevice`, which
    # re-raised ImportError and made the fallback unreachable.
    sd = None
    SOUNDDEVICE_AVAILABLE = False
    print("Warning: sounddevice not available. Install with: pip install sounddevice")


# ============================================================
# Configuration
# ============================================================
CONFIG = {
    # Streaming Sortformer geometry (units: post-subsampling frames).
    'chunk_len': 6,
    'chunk_right_context': 1,
    'chunk_left_context': 1,
    'fifo_len': 40,
    'spkcache_len': 120,
    'spkcache_update_period': 32,
    'subsampling_factor': 8,

    # Mel preprocessor parameters (samples @ 16 kHz: 25 ms window, 10 ms stride).
    'sample_rate': 16000,
    'mel_window': 400,
    'mel_stride': 160,

    # Audio settings
    'audio_chunk_samples': 1280,  # 80ms chunks from mic
    'channels': 1,
}

# Derived sizes used by the CoreML pipeline.
CONFIG['spkcache_input_len'] = CONFIG['spkcache_len']
CONFIG['fifo_input_len'] = CONFIG['fifo_len']
# Mel frames per diarizer input chunk (core + left/right context, pre-subsampling).
CONFIG['chunk_frames'] = (CONFIG['chunk_len'] + CONFIG['chunk_left_context'] + CONFIG['chunk_right_context']) * CONFIG['subsampling_factor']
# Raw samples needed to produce exactly chunk_frames mel frames.
CONFIG['preproc_audio_samples'] = (CONFIG['chunk_frames'] - 1) * CONFIG['mel_stride'] + CONFIG['mel_window']
|
| 68 |
+
|
| 69 |
+
class MicrophoneStream:
    """Pushes fixed-size microphone blocks onto a queue via sounddevice."""

    def __init__(self, sample_rate, chunk_size, audio_queue):
        # The stream itself is created lazily in start().
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.audio_queue = audio_queue
        self.stream = None
        self.running = False

    def start(self):
        """Open and start the input stream. Returns True on success."""
        if not SOUNDDEVICE_AVAILABLE:
            print("sounddevice not available!")
            return False

        def _on_block(indata, frames, time_info, status):
            # Runs on sounddevice's audio thread; keep the work minimal.
            if status:
                print(f"Audio status: {status}")
            # indata is already float32 in range [-1, 1]
            mono = indata[:, 0].copy()  # keep only the first channel
            self.audio_queue.put(mono)

        self.stream = sd.InputStream(
            samplerate=self.sample_rate,
            channels=1,
            dtype=np.float32,
            blocksize=self.chunk_size,
            callback=_on_block,
        )
        self.stream.start()
        self.running = True
        print("Microphone started...")
        return True

    def stop(self):
        """Stop and close the stream if one was opened."""
        self.running = False
        if self.stream:
            self.stream.stop()
            self.stream.close()
        print("Microphone stopped.")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class StreamingDiarizer:
    """Real-time streaming diarization using CoreML.

    Owns three pieces of rolling state:
      * ``audio_buffer``   - raw float32 samples not yet featurized
      * ``feature_buffer`` - mel features not yet consumed by the diarizer
        (time is the last axis: sliced and concatenated on axis 2)
      * ``state``          - NeMo streaming state (speaker cache + FIFO)

    Push samples with ``add_audio``; pull results with ``process``, which
    runs the CoreML preprocessor and main model and returns the newest
    per-frame speaker probabilities.
    """

    def __init__(self, nemo_model, preproc_model, main_model, config):
        # NeMo helpers: provide init_streaming_state / streaming_update.
        self.modules = nemo_model.sortformer_modules
        self.preproc_model = preproc_model  # CoreML: audio -> mel features
        self.main_model = main_model        # CoreML: features + state -> speaker preds
        self.config = config

        # Audio buffer
        self.audio_buffer = np.array([], dtype=np.float32)

        # Feature buffer
        self.feature_buffer = None
        self.features_processed = 0

        # Diarization state
        self.state = self.modules.init_streaming_state(batch_size=1, device='cpu')
        self.all_probs = []  # List of [T, 4] arrays

        # Chunk tracking
        self.diar_chunk_idx = 0     # next diarizer chunk to emit
        self.preproc_chunk_idx = 0  # number of preprocessor windows consumed

        # Derived params (all in mel-frame units)
        self.subsampling = config['subsampling_factor']
        self.core_frames = config['chunk_len'] * self.subsampling
        self.left_ctx = config['chunk_left_context'] * self.subsampling
        self.right_ctx = config['chunk_right_context'] * self.subsampling

        # Audio hop for preprocessor
        self.audio_hop = config['preproc_audio_samples'] - config['mel_window']
        # Leading mel frames of each window that re-cover audio from the
        # previous window; dropped on every window after the first.
        self.overlap_frames = (config['mel_window'] - config['mel_stride']) // config['mel_stride'] + 1

    def add_audio(self, audio_chunk):
        """Add new audio samples."""
        self.audio_buffer = np.concatenate([self.audio_buffer, audio_chunk])

    def process(self):
        """
        Process available audio through preprocessor and diarizer.
        Returns new probability frames if available.
        """
        new_probs = None

        # Step 1: Run preprocessor on available audio
        while len(self.audio_buffer) >= self.config['preproc_audio_samples']:
            audio_chunk = self.audio_buffer[:self.config['preproc_audio_samples']]

            preproc_inputs = {
                "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32),
                "length": np.array([self.config['preproc_audio_samples']], dtype=np.int32)
            }

            preproc_out = self.preproc_model.predict(preproc_inputs)
            # assumes features are [1, mel, T] with time on axis 2 — TODO confirm
            # against the exported preprocessor's output description
            feat_chunk = np.array(preproc_out["features"])
            feat_len = int(preproc_out["feature_lengths"][0])

            if self.preproc_chunk_idx == 0:
                valid_feats = feat_chunk[:, :, :feat_len]
            else:
                # Drop frames already produced by the previous window.
                valid_feats = feat_chunk[:, :, self.overlap_frames:feat_len]

            if self.feature_buffer is None:
                self.feature_buffer = valid_feats
            else:
                self.feature_buffer = np.concatenate([self.feature_buffer, valid_feats], axis=2)

            # Advance by the hop (window length minus one mel window).
            self.audio_buffer = self.audio_buffer[self.audio_hop:]
            self.preproc_chunk_idx += 1

        if self.feature_buffer is None:
            return None

        # Step 2: Run diarization on available features
        total_features = self.feature_buffer.shape[2]

        while True:
            # Calculate chunk boundaries
            chunk_start = self.diar_chunk_idx * self.core_frames
            chunk_end = chunk_start + self.core_frames

            # Need right context
            required_features = chunk_end + self.right_ctx

            if required_features > total_features:
                break  # Not enough features yet

            # Extract with context
            left_offset = min(self.left_ctx, chunk_start)
            right_offset = min(self.right_ctx, total_features - chunk_end)

            feat_start = chunk_start - left_offset
            feat_end = chunk_end + right_offset

            chunk_feat = self.feature_buffer[:, :, feat_start:feat_end]
            chunk_feat_tensor = torch.from_numpy(chunk_feat).float()
            actual_len = chunk_feat.shape[2]

            # Transpose to [B, T, D]
            chunk_t = chunk_feat_tensor.transpose(1, 2)

            # Pad if needed — the CoreML model takes a fixed chunk_frames length;
            # the true length travels separately in chunk_lengths.
            if actual_len < self.config['chunk_frames']:
                pad_len = self.config['chunk_frames'] - actual_len
                chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len))
            else:
                chunk_in = chunk_t[:, :self.config['chunk_frames'], :]

            # State preparation: pad/trim the speaker cache and FIFO to the
            # fixed sizes the exported model expects; true lengths are passed
            # alongside so padding is ignorable.
            curr_spk_len = self.state.spkcache.shape[1]
            curr_fifo_len = self.state.fifo.shape[1]

            current_spkcache = self.state.spkcache
            if curr_spk_len < self.config['spkcache_input_len']:
                current_spkcache = torch.nn.functional.pad(
                    current_spkcache, (0, 0, 0, self.config['spkcache_input_len'] - curr_spk_len)
                )
            elif curr_spk_len > self.config['spkcache_input_len']:
                current_spkcache = current_spkcache[:, :self.config['spkcache_input_len'], :]

            current_fifo = self.state.fifo
            if curr_fifo_len < self.config['fifo_input_len']:
                current_fifo = torch.nn.functional.pad(
                    current_fifo, (0, 0, 0, self.config['fifo_input_len'] - curr_fifo_len)
                )
            elif curr_fifo_len > self.config['fifo_input_len']:
                current_fifo = current_fifo[:, :self.config['fifo_input_len'], :]

            # CoreML inference
            coreml_inputs = {
                "chunk": chunk_in.numpy().astype(np.float32),
                "chunk_lengths": np.array([actual_len], dtype=np.int32),
                "spkcache": current_spkcache.numpy().astype(np.float32),
                "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
                "fifo": current_fifo.numpy().astype(np.float32),
                "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32)
            }

            st_time = time.time_ns()
            coreml_out = self.main_model.predict(coreml_inputs)
            ed_time = time.time_ns()
            print(f"duration: {1e-6 * (ed_time - st_time)}")  # per-chunk latency, ms

            pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
            chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
            chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

            # Keep only valid (unpadded) embedding frames.
            chunk_embs = chunk_embs[:, :chunk_emb_len, :]

            # Convert context sizes from mel frames to post-subsampling frames.
            lc = round(left_offset / self.subsampling)
            rc = math.ceil(right_offset / self.subsampling)

            # NeMo updates the speaker cache / FIFO and returns the chunk's
            # probabilities with the context frames handled.
            self.state, chunk_probs = self.modules.streaming_update(
                streaming_state=self.state,
                chunk=chunk_embs,
                preds=pred_logits,
                lc=lc,
                rc=rc
            )

            # Store probabilities
            probs_np = chunk_probs.squeeze(0).detach().cpu().numpy()
            self.all_probs.append(probs_np)

            new_probs = probs_np
            self.diar_chunk_idx += 1

        return new_probs

    def get_all_probs(self):
        """Get all accumulated probabilities as one [total_frames, 4] array."""
        if len(self.all_probs) > 0:
            return np.concatenate(self.all_probs, axis=0)
        return None
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def run_mic_inference(model_name, coreml_dir):
    """Run real-time microphone diarization.

    Args:
        model_name: Pretrained NeMo Sortformer identifier; the model is loaded
            only for its streaming-state helpers (``sortformer_modules``).
        coreml_dir: Directory containing the exported .mlpackage models.

    Captures mic audio into a queue, feeds it to a StreamingDiarizer, and
    redraws a live speaker-probability heatmap until Ctrl+C.
    """

    if not SOUNDDEVICE_AVAILABLE:
        print("Cannot run mic inference without sounddevice!")
        return

    print("=" * 70)
    print("Real-Time Microphone Diarization")
    print("=" * 70)

    # Load NeMo model
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Configure the streaming geometry to match the exported CoreML models.
    modules = nemo_model.sortformer_modules
    modules.chunk_len = CONFIG['chunk_len']
    modules.chunk_right_context = CONFIG['chunk_right_context']
    modules.chunk_left_context = CONFIG['chunk_left_context']
    modules.fifo_len = CONFIG['fifo_len']
    modules.spkcache_len = CONFIG['spkcache_len']
    modules.spkcache_update_period = CONFIG['spkcache_update_period']

    # Disable dither/padding for deterministic features.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    # Load CoreML models
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Pipeline_Preprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPipeline.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    # Create diarizer
    diarizer = StreamingDiarizer(nemo_model, preproc_model, main_model, CONFIG)

    # Audio queue (filled by the sounddevice callback thread)
    audio_queue = queue.Queue()

    # Start microphone
    mic = MicrophoneStream(
        sample_rate=CONFIG['sample_rate'],
        chunk_size=CONFIG['audio_chunk_samples'],
        audio_queue=audio_queue
    )

    if not mic.start():
        return

    # Setup plot
    plt.ion()
    fig, ax = plt.subplots(figsize=(14, 4))

    print("\nListening... Press Ctrl+C to stop.\n")

    try:
        last_update = time.time()

        while True:
            # Get audio from queue
            while not audio_queue.empty():
                audio_chunk = audio_queue.get()
                diarizer.add_audio(audio_chunk)

            # Process
            new_probs = diarizer.process()

            # Update plot periodically
            if time.time() - last_update > 0.16:  # Update every 160ms
                all_probs = diarizer.get_all_probs()

                if all_probs is not None and len(all_probs) > 0:
                    ax.clear()

                    # Show last 200 frames (~16 seconds)
                    display_frames = min(200, len(all_probs))
                    display_probs = all_probs[-display_frames:]

                    sns.heatmap(
                        display_probs.T,
                        ax=ax,
                        cmap="viridis",
                        vmin=0, vmax=1,
                        yticklabels=[f"Spk {i}" for i in range(4)],
                        cbar=False
                    )

                    ax.set_xlabel("Time (frames, 80ms each)")
                    ax.set_ylabel("Speaker")
                    ax.set_title(f"Live Diarization - Total: {len(all_probs)} frames ({len(all_probs)*0.08:.1f}s)")

                    plt.draw()
                    plt.pause(0.01)

                last_update = time.time()

            time.sleep(0.01)

    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        mic.stop()
        plt.ioff()
        plt.close()

        # Final summary
        all_probs = diarizer.get_all_probs()
        if all_probs is not None:
            print(f"\nTotal processed: {len(all_probs)} frames ({len(all_probs)*0.08:.1f} seconds)")
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def run_file_demo(model_name, coreml_dir, audio_path):
    """Run demo on audio file with live updating plot.

    Args:
        model_name: Pretrained NeMo Sortformer identifier (used for its
            streaming-state helpers).
        coreml_dir: Directory containing the exported .mlpackage models.
        audio_path: Audio file to stream; resampled to 16 kHz mono.

    Feeds the file to a StreamingDiarizer in mic-sized chunks and updates a
    heatmap after each step, simulating live streaming.
    """

    print("=" * 70)
    print("File Demo with Live Updating Plot")
    print("=" * 70)

    # Load NeMo model
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Configure the streaming geometry to match the exported CoreML models.
    modules = nemo_model.sortformer_modules
    modules.chunk_len = CONFIG['chunk_len']
    modules.chunk_right_context = CONFIG['chunk_right_context']
    modules.chunk_left_context = CONFIG['chunk_left_context']
    modules.fifo_len = CONFIG['fifo_len']
    modules.spkcache_len = CONFIG['spkcache_len']
    modules.spkcache_update_period = CONFIG['spkcache_update_period']

    # Disable dither/padding for deterministic features.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    # Load CoreML models
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Pipeline_Preprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPipeline.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    # Load audio file
    import librosa
    audio, _ = librosa.load(audio_path, sr=CONFIG['sample_rate'], mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/CONFIG['sample_rate']:.1f}s)")

    # Create diarizer
    diarizer = StreamingDiarizer(nemo_model, preproc_model, main_model, CONFIG)

    # Setup plot
    plt.ion()
    fig, ax = plt.subplots(figsize=(14, 4))

    # Simulate streaming: feed the file in mic-sized chunks.
    chunk_size = CONFIG['audio_chunk_samples']
    offset = 0

    print("\nStreaming audio with live plot...")

    try:
        while offset < len(audio):
            # Add audio chunk
            chunk_end = min(offset + chunk_size, len(audio))
            audio_chunk = audio[offset:chunk_end]
            diarizer.add_audio(audio_chunk)
            offset = chunk_end

            # Process
            diarizer.process()

            # Update plot
            all_probs = diarizer.get_all_probs()

            if all_probs is not None and len(all_probs) > 0:
                ax.clear()

                sns.heatmap(
                    all_probs.T,
                    ax=ax,
                    cmap="viridis",
                    vmin=0, vmax=1,
                    yticklabels=[f"Spk {i}" for i in range(4)],
                    cbar=False
                )

                ax.set_xlabel("Time (frames, 80ms each)")
                ax.set_ylabel("Speaker")
                ax.set_title(f"Streaming Diarization - {len(all_probs)} frames")

                plt.draw()
                plt.pause(0.05)

            # Simulate real-time (optional - comment out for fast mode)
            # time.sleep(chunk_size / CONFIG['sample_rate'])

    except KeyboardInterrupt:
        print("\nStopped.")

    plt.ioff()

    # Final plot
    all_probs = diarizer.get_all_probs()
    if all_probs is not None:
        print(f"\nTotal: {len(all_probs)} frames ({len(all_probs)*0.08:.1f}s)")
        plt.show()
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="nvidia/diar_streaming_sortformer_4spk-v2.1")
    parser.add_argument("--coreml_dir", default="coreml_models")
    parser.add_argument("--audio_path", default="audio.wav")
    parser.add_argument("--mic", action="store_true", help="Use microphone input")
    args = parser.parse_args()

    # BUG FIX: --mic was parsed but ignored; mic inference always ran and the
    # intended dispatch was commented out. Honor the flag as documented.
    if args.mic:
        run_mic_inference(args.model_name, args.coreml_dir)
    else:
        run_file_demo(args.model_name, args.coreml_dir, args.audio_path)
|
nemo_streaming_reference.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Get exact NeMo streaming inference output for comparison with Swift."""
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import librosa
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 13 |
+
|
| 14 |
+
def main():
    """Produce a NeMo streaming-inference reference run and dump it to JSON.

    Loads the local .nemo checkpoint, replays the audio file chunk-by-chunk
    through ``forward_streaming_step`` with the same streaming configuration
    the Swift port uses, prints a frame-level timeline, and writes the raw
    probabilities to /tmp/nemo_streaming_reference.json for comparison.
    """
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        'diar_streaming_sortformer_4spk-v2.nemo', map_location='cpu'
    )
    model.eval()

    # Disable dither for deterministic output
    if hasattr(model.preprocessor, 'featurizer'):
        if hasattr(model.preprocessor.featurizer, 'dither'):
            model.preprocessor.featurizer.dither = 0.0

    # Configure for Gradient Descent's streaming config (same as Swift)
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print(f"  fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

    # Load audio
    audio_path = "../audio.wav"
    audio, sr = librosa.load(audio_path, sr=16000, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/16000:.2f}s)")

    waveform = torch.from_numpy(audio).unsqueeze(0).float()

    # Get mel features using model's preprocessor
    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform, audio_signal_length=audio_len
        )

    # features is [batch, mel, time], need [batch, time, mel] for streaming
    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")

    # Streaming inference using forward_streaming_step
    subsampling = modules.subsampling_factor  # 8
    chunk_len = modules.chunk_len  # 6
    left_context = modules.chunk_left_context  # 1
    right_context = modules.chunk_right_context  # 7
    core_frames = chunk_len * subsampling  # 48 mel frames

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    # Initialize streaming state
    streaming_state = modules.init_streaming_state(device=features.device)

    # Initialize total_preds tensor (grows as chunks are processed)
    total_preds = torch.zeros((1, 0, 4), device=features.device)

    all_preds = []
    chunk_idx = 0

    # Process chunks like streaming_feat_loader
    stt_feat = 0
    while stt_feat < total_mel_frames:
        end_feat = min(stt_feat + core_frames, total_mel_frames)

        # Calculate context (in mel frames); clamped at the signal edges
        left_offset = min(left_context * subsampling, stt_feat)
        right_offset = min(right_context * subsampling, total_mel_frames - end_feat)

        chunk_start = stt_feat - left_offset
        chunk_end = end_feat + right_offset

        # Extract chunk - [batch, mel, time] -> [batch, time, mel]
        chunk = features[:, :, chunk_start:chunk_end]  # [1, 128, T]
        chunk_t = chunk.transpose(1, 2)  # [1, T, 128]
        chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long)

        with torch.no_grad():
            # Use forward_streaming_step
            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

        chunk_idx += 1
        stt_feat = end_feat

    # total_preds now contains all predictions
    all_preds = total_preds[0].numpy()  # [total_frames, 4]
    print(f"\nTotal output frames: {all_preds.shape[0]}")
    print(f"Predictions shape: {all_preds.shape}")

    # Print timeline
    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame  Time   Spk0  Spk1  Spk2  Spk3  | Visual")
    print("-" * 60)

    for frame in range(all_preds.shape[0]):
        time_sec = frame * 0.08
        probs = all_preds[frame]
        visual = ['■' if p > 0.55 else '·' for p in probs]
        print(f"{frame:5d}  {time_sec:5.2f}s  {probs[0]:.3f} {probs[1]:.3f} {probs[2]:.3f} {probs[3]:.3f} | [{visual[0]}{visual[1]}{visual[2]}{visual[3]}]")

    print("-" * 60)

    # Speaker activity summary
    print("\n=== Speaker Activity Summary ===")
    threshold = 0.55
    for spk in range(4):
        active_frames = np.sum(all_preds[:, spk] > threshold)
        active_time = active_frames * 0.08
        percent = active_time / (all_preds.shape[0] * 0.08) * 100
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")

    # Save to JSON for comparison
    output = {
        "total_frames": int(all_preds.shape[0]),
        "frame_duration_seconds": 0.08,
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        }
    }

    with open("/tmp/nemo_streaming_reference.json", "w") as f:
        json.dump(output, f, indent=2)
    print("\nSaved to /tmp/nemo_streaming_reference.json")
|
| 151 |
+
|
| 152 |
+
# Script entry point: generate the reference output.
if __name__ == "__main__":
    main()
|
streaming_inference.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import coremltools as ct
|
| 4 |
+
import librosa
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
# Import NeMo components for State Logic
|
| 11 |
+
try:
|
| 12 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 13 |
+
# Try importing SortformerModules directly for type hints if needed, but we can access via model instance
|
| 14 |
+
from nemo.collections.asr.modules.sortformer_modules import SortformerModules
|
| 15 |
+
except ImportError as e:
|
| 16 |
+
print(f"Error importing NeMo: {e}")
|
| 17 |
+
sys.exit(1)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def streaming_feat_loader(modules, feat_seq, feat_seq_length, feat_seq_offset):
    """Yield successive feature chunks (plus context) for streaming inference.

    Adapted from NeMo's ``SortformerModules.streaming_feat_loader``: the
    feature sequence is cut into windows of ``chunk_len * subsampling_factor``
    frames, each extended by up to ``chunk_left_context``/``chunk_right_context``
    chunks of surrounding frames.

    Args:
        modules: Object exposing ``chunk_len`` and ``subsampling_factor``
            (and optionally ``chunk_left_context`` / ``chunk_right_context``).
        feat_seq (torch.Tensor): Features, shape (batch, feat_dim, frames).
        feat_seq_length (torch.Tensor): Valid lengths, shape (batch,).
        feat_seq_offset (torch.Tensor): Start offsets, shape (batch,).

    Yields:
        (chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset)
        where ``chunk_feat_seq_t`` is transposed to (batch, frames, feat_dim)
        and the offsets are expressed in feature frames.
    """
    feat_len = feat_seq.shape[2]
    chunk_len = modules.chunk_len
    subsampling_factor = modules.subsampling_factor
    chunk_left_context = getattr(modules, 'chunk_left_context', 0)
    chunk_right_context = getattr(modules, 'chunk_right_context', 0)

    # Core step size in feature frames; context is added on top of it.
    step = chunk_len * subsampling_factor
    num_chunks = math.ceil(feat_len / step)
    print(f"streaming_feat_loader: feat_len={feat_len}, num_chunks={num_chunks}, "
          f"chunk_len={chunk_len}, subsampling_factor={subsampling_factor}")

    for chunk_idx, stt_feat in enumerate(range(0, feat_len, step)):
        end_feat = min(stt_feat + step, feat_len)
        # Context never reaches past the sequence boundaries.
        left_offset = min(chunk_left_context * subsampling_factor, stt_feat)
        right_offset = min(chunk_right_context * subsampling_factor, feat_len - end_feat)

        chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
        # Valid frames inside this window, clamped to the window size;
        # zeroed out for batch items that start after this window ends.
        feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
            0, chunk_feat_seq.shape[2]
        )
        feat_lengths = feat_lengths * (feat_seq_offset < end_feat)

        # (batch, feat_dim, frames) -> (batch, frames, feat_dim)
        chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)

        print(f"  chunk_idx: {chunk_idx}, chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
              f"feat_lengths: {feat_lengths}, left_offset: {left_offset}, right_offset: {right_offset}")

        yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def run_streaming_inference(model_name, coreml_dir, audio_path):
    """Run chunked Sortformer diarization with the CoreML main model.

    Features for the whole file are extracted once with NeMo's own
    preprocessor (process_signal); the per-chunk forward pass runs in the
    exported CoreML model, while NeMo's SortformerModules keeps the
    speaker-cache / FIFO state logic (init_streaming_state / streaming_update).

    Args:
        model_name: Local .nemo checkpoint path or a pretrained model id.
        coreml_dir: Directory holding the exported .mlpackage files.
        audio_path: Path to the audio file (loaded mono at 16 kHz).

    Returns:
        torch.Tensor of concatenated per-chunk speaker probabilities
        ([1, total_frames, num_speakers] per the shapes printed below),
        or None if no chunks were produced.
    """
    print(f"Loading NeMo Model (for Python Streaming Logic): {model_name}")
    # Local file -> restore_from; otherwise treat the name as a hub model id.
    if os.path.exists(model_name):
        nemo_model = SortformerEncLabelModel.restore_from(model_name, map_location="cpu")
    else:
        nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()
    modules = nemo_model.sortformer_modules

    # --- Override Config to match CoreML Export (Low Latency) ---
    # These must mirror the settings used when the CoreML model was exported,
    # otherwise chunk/state shapes won't line up.
    print("Overriding Config (Inference) to match CoreML...")
    modules.chunk_len = 4
    modules.chunk_right_context = 1  # 1 chunk of right context
    modules.chunk_left_context = 2   # 2 chunks of left context
    # Match CoreML export sizes (from model spec)
    modules.fifo_len = 63
    modules.spkcache_len = 63
    modules.spkcache_update_period = 50  # Match CoreML export

    # CoreML fixed input sizes (must match export settings)
    # With left_context=2, right_context=1: (4+2+1)*8 = 56 feature frames.
    COREML_CHUNK_FRAMES = 56
    COREML_SPKCACHE_LEN = 63
    COREML_FIFO_LEN = 63

    # Disable dither and pad_to (as diarize does) so feature extraction is
    # deterministic and not padded to a multiple.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        if hasattr(nemo_model.preprocessor.featurizer, 'dither'):
            nemo_model.preprocessor.featurizer.dither = 0.0
        if hasattr(nemo_model.preprocessor.featurizer, 'pad_to'):
            nemo_model.preprocessor.featurizer.pad_to = 0

    # CoreML Models - use CPU_ONLY for compatibility
    # NOTE(review): preproc_model is loaded here but never used in this
    # function — features come from NeMo's process_signal below.
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPreprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Sortformer.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    # Config
    chunk_len = modules.chunk_len              # Output frames (e.g., 4 for low latency)
    subsampling_factor = modules.subsampling_factor  # 8
    sample_rate = 16000

    print(f"Chunk Config: {chunk_len} output frames (diar), subsampling_factor={subsampling_factor}")

    # Load Audio
    print(f"Loading Audio: {audio_path}")
    full_audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    total_samples = len(full_audio)
    print(f"Total Samples: {total_samples} ({total_samples/sample_rate:.2f}s)")

    # === Step 1: Extract features for the ENTIRE audio using preprocessor ===
    # This matches NeMo's approach: process_signal -> forward_streaming
    print("Extracting features for entire audio...")
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()  # [1, samples]
    audio_length = torch.tensor([total_samples], dtype=torch.long)

    with torch.no_grad():
        # Use process_signal for proper normalization (same as forward())
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )

    print(f"Processed signal shape: {processed_signal.shape}")  # [1, 128, T]
    print(f"Processed signal length: {processed_signal_length}")

    # Trim to actual length
    processed_signal = processed_signal[:, :, :processed_signal_length.max()]

    # === Step 2: Initialize streaming state ===
    print("Initializing Streaming State...")
    state = modules.init_streaming_state(batch_size=1, device='cpu')

    # === Step 3: Use streaming_feat_loader to chunk features (matches NeMo exactly) ===
    batch_size = processed_signal.shape[0]
    processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long)

    all_preds = []

    feat_loader = streaming_feat_loader(
        modules=modules,
        feat_seq=processed_signal,
        feat_seq_length=processed_signal_length,
        feat_seq_offset=processed_signal_offset,
    )

    for chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in feat_loader:
        # Prepare inputs for CoreML model.
        # Pad chunk to fixed size for CoreML (fixed-shape inputs required).
        chunk_actual_len = chunk_feat_seq_t.shape[1]
        if chunk_actual_len < COREML_CHUNK_FRAMES:
            pad_len = COREML_CHUNK_FRAMES - chunk_actual_len
            chunk_in = torch.nn.functional.pad(chunk_feat_seq_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_feat_seq_t[:, :COREML_CHUNK_FRAMES, :]
        chunk_len_in = feat_lengths.long()  # actual (unpadded) length

        # Get actual lengths from state (pad tensors but track real lengths)
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]
        # Prepare SpkCache - Pad to CoreML fixed size
        current_spkcache = state.spkcache

        if curr_spk_len < COREML_SPKCACHE_LEN:
            pad_len = COREML_SPKCACHE_LEN - curr_spk_len
            current_spkcache = torch.nn.functional.pad(current_spkcache, (0, 0, 0, pad_len))
        elif curr_spk_len > COREML_SPKCACHE_LEN:
            current_spkcache = current_spkcache[:, :COREML_SPKCACHE_LEN, :]

        spkcache_in = current_spkcache
        # Use actual length, not padded length
        spkcache_len_in = torch.tensor([curr_spk_len], dtype=torch.long)

        # Prepare FIFO - Pad to CoreML fixed size
        current_fifo = state.fifo

        if curr_fifo_len < COREML_FIFO_LEN:
            pad_len = COREML_FIFO_LEN - curr_fifo_len
            current_fifo = torch.nn.functional.pad(current_fifo, (0, 0, 0, pad_len))
        elif curr_fifo_len > COREML_FIFO_LEN:
            current_fifo = current_fifo[:, :COREML_FIFO_LEN, :]

        fifo_in = current_fifo
        fifo_len_in = torch.tensor([curr_fifo_len], dtype=torch.long)

        # === Run CoreML Model ===
        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": chunk_len_in.numpy().astype(np.int32),
            "spkcache": spkcache_in.numpy().astype(np.float32),
            "spkcache_lengths": spkcache_len_in.numpy().astype(np.int32),
            "fifo": fifo_in.numpy().astype(np.float32),
            "fifo_lengths": fifo_len_in.numpy().astype(np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        # Convert outputs back to torch tensors
        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Trim chunk_embs to actual length (drop padded frames)
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Compute lc and rc for streaming_update (in embedding/diar frames, not feature frames)
        # NeMo does: lc = round(left_offset / encoder.subsampling_factor)
        #            rc = math.ceil(right_offset / encoder.subsampling_factor)
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        # Update state using streaming_update with proper lc/rc
        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        # chunk_probs is the prediction for the current chunk
        all_preds.append(chunk_probs)

        print(f"Processed chunk {chunk_idx + 1}, chunk_probs shape: {chunk_probs.shape}", end='\r')

    print(f"\nFinished. Total Chunks: {len(all_preds)}")
    if len(all_preds) > 0:
        final_probs = torch.cat(all_preds, dim=1)  # [1, TotalFrames, Spks]
        print(f"Final Predictions Shape: {final_probs.shape}")
        return final_probs
    return None
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
    # CLI entry point: parse options and hand them to the streaming driver.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--model_name", "nvidia/diar_streaming_sortformer_4spk-v2.1"),
        ("--coreml_dir", "coreml_models"),
        ("--audio_path", "test2.wav"),
    ):
        cli.add_argument(flag, default=default)
    opts = cli.parse_args()

    run_streaming_inference(opts.model_name, opts.coreml_dir, opts.audio_path)
|
streaming_preproc_inference.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
True Streaming CoreML Diarization
|
| 3 |
+
|
| 4 |
+
This script implements true streaming inference:
|
| 5 |
+
Audio chunks → CoreML Preprocessor → Feature Buffer → CoreML Main Model → Predictions
|
| 6 |
+
|
| 7 |
+
Audio is processed incrementally, features are accumulated with proper context handling.
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
import coremltools as ct
|
| 15 |
+
import librosa
|
| 16 |
+
import argparse
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
# Import NeMo for state management (streaming_update) only
|
| 20 |
+
from nemo.collections.asr.models import SortformerEncLabelModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ============================================================
# Configuration for Sortformer16.mlpackage
# All values must match the settings used at CoreML export time;
# shape mismatches with the .mlpackage will fail at predict().
# ============================================================
CONFIG = {
    'chunk_len': 4,              # Diarization chunk length (output frames per step)
    'chunk_right_context': 1,    # Right context, in chunks
    'chunk_left_context': 2,     # Left context, in chunks
    'fifo_len': 63,              # FIFO state length (frames)
    'spkcache_len': 63,          # Speaker-cache state length (frames)
    'spkcache_update_period': 50,
    'subsampling_factor': 8,     # Feature frames per diarization frame
    'sample_rate': 16000,        # Hz

    # Derived values
    'chunk_frames': 56,          # (4+2+1)*8 = 56 feature frames for CoreML input
    'spkcache_input_len': 63,    # Fixed spkcache input size expected by CoreML
    'fifo_input_len': 63,        # Fixed fifo input size expected by CoreML

    # Preprocessor settings
    'preproc_audio_samples': 9200,  # CoreML preprocessor fixed input size (samples)
    'mel_window': 400,              # 25ms @ 16kHz
    'mel_stride': 160,              # 10ms @ 16kHz
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_true_streaming(nemo_model, preproc_model, main_model, audio_path, config):
    """
    True streaming inference: audio chunks → CoreML preproc → CoreML main model.

    Strategy:
      1. Process audio in fixed-size windows through the CoreML preprocessor.
      2. Accumulate the per-window features into one sequence (overlap trimmed).
      3. Run the standard chunked diarization loop over the stitched features,
         with NeMo's streaming_update maintaining spkcache/FIFO state.

    Args:
        nemo_model: NeMo SortformerEncLabelModel (state logic only).
        preproc_model: CoreML preprocessor MLModel (audio -> mel features).
        main_model: CoreML main MLModel (chunk + state -> speaker preds).
        audio_path: Path to the audio file.
        config: Dict of export-time constants (see CONFIG).

    Returns:
        torch.Tensor of concatenated chunk probabilities, or None if empty.
    """
    modules = nemo_model.sortformer_modules
    subsampling_factor = config['subsampling_factor']

    # Load full audio (simulating microphone input)
    full_audio, sr = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    total_samples = len(full_audio)

    print(f"Total audio samples: {total_samples}")

    # Preprocessing parameters
    mel_window = config['mel_window']
    mel_stride = config['mel_stride']
    preproc_len = config['preproc_audio_samples']

    # Audio hop for preprocessor (to avoid overlap in features):
    # consecutive windows share exactly one analysis window of audio.
    audio_hop = preproc_len - mel_window  # 8800 samples

    # Feature accumulator
    all_features = []
    audio_offset = 0
    preproc_chunk_idx = 0

    # Step 1: Process all audio through preprocessor to get features
    print("Step 1: Extracting features via CoreML preprocessor...")
    while audio_offset < total_samples:
        # Get audio chunk
        chunk_end = min(audio_offset + preproc_len, total_samples)
        audio_chunk = full_audio[audio_offset:chunk_end]
        actual_samples = len(audio_chunk)

        # Pad the tail window with zeros; the real length is passed separately.
        if actual_samples < preproc_len:
            audio_chunk = np.pad(audio_chunk, (0, preproc_len - actual_samples))

        # Run preprocessor
        preproc_inputs = {
            "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32),
            "length": np.array([actual_samples], dtype=np.int32)
        }

        preproc_out = preproc_model.predict(preproc_inputs)
        feat_chunk = np.array(preproc_out["features"])  # [1, 128, frames]
        feat_len = int(preproc_out["feature_lengths"][0])

        # Extract valid features and handle overlap between windows.
        if preproc_chunk_idx == 0:
            # First chunk: keep all
            valid_feats = feat_chunk[:, :, :feat_len]
        else:
            # Subsequent: skip frames duplicated by the window overlap.
            # NOTE(review): (400-160)//160 + 1 = 2 frames here. Stitching
            # windowed features is an approximation of one-shot preprocessing
            # (edge frames see zero-padded context) — validate against
            # run_reference before trusting the output.
            overlap_frames = (mel_window - mel_stride) // mel_stride + 1  # ~2-3 frames
            valid_feats = feat_chunk[:, :, overlap_frames:feat_len]

        all_features.append(valid_feats)

        audio_offset += audio_hop
        preproc_chunk_idx += 1

        print(f"\r  Processed audio chunk {preproc_chunk_idx}, features so far: {sum(f.shape[2] for f in all_features)}", end='')

    print()

    # Concatenate all features
    full_features = np.concatenate(all_features, axis=2)  # [1, 128, total_frames]
    processed_signal = torch.from_numpy(full_features).float()
    processed_signal_length = torch.tensor([full_features.shape[2]], dtype=torch.long)

    print(f"Total features extracted: {processed_signal.shape}")

    # Step 2: Run diarization streaming loop (same as NeMo reference)
    print("Step 2: Running diarization streaming...")

    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []

    feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat, chunk_idx = 0, 0, 0

    while end_feat < feat_len:
        # Context offsets clamped at the sequence boundaries.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        # Extract chunk with context
        chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset]
        actual_len = chunk_feat.shape[2]

        # Transpose to [B, T, D]
        chunk_t = chunk_feat.transpose(1, 2)

        # Pad to fixed size expected by the CoreML model.
        if actual_len < config['chunk_frames']:
            pad_len = config['chunk_frames'] - actual_len
            chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_t[:, :config['chunk_frames'], :]

        # State preparation: pad/trim tensors to fixed shapes, pass real lengths.
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]

        current_spkcache = state.spkcache
        if curr_spk_len < config['spkcache_input_len']:
            current_spkcache = torch.nn.functional.pad(
                current_spkcache, (0, 0, 0, config['spkcache_input_len'] - curr_spk_len)
            )
        elif curr_spk_len > config['spkcache_input_len']:
            current_spkcache = current_spkcache[:, :config['spkcache_input_len'], :]

        current_fifo = state.fifo
        if curr_fifo_len < config['fifo_input_len']:
            current_fifo = torch.nn.functional.pad(
                current_fifo, (0, 0, 0, config['fifo_input_len'] - curr_fifo_len)
            )
        elif curr_fifo_len > config['fifo_input_len']:
            current_fifo = current_fifo[:, :config['fifo_input_len'], :]

        # CoreML inference
        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([actual_len], dtype=np.int32),
            "spkcache": current_spkcache.numpy().astype(np.float32),
            "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
            "fifo": current_fifo.numpy().astype(np.float32),
            "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Drop padded embedding frames before the state update.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes in diarization frames (matches NeMo's rounding).
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)
        stt_feat = end_feat
        chunk_idx += 1

        print(f"\r  Diarization chunk {chunk_idx}", end='')

    print()

    if len(all_preds) > 0:
        return torch.cat(all_preds, dim=1)
    return None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def run_reference(nemo_model, main_model, audio_path, config):
    """
    Reference implementation using NeMo preprocessing.

    Same chunked diarization loop as run_true_streaming, but features come
    from NeMo's own process_signal over the whole file instead of the CoreML
    preprocessor — used as ground truth when validating the streaming path.

    Args:
        nemo_model: NeMo SortformerEncLabelModel (preprocessing + state logic).
        main_model: CoreML main MLModel (chunk + state -> speaker preds).
        audio_path: Path to the audio file.
        config: Dict of export-time constants (see CONFIG).

    Returns:
        torch.Tensor of concatenated chunk probabilities, or None if empty.
    """
    modules = nemo_model.sortformer_modules
    # Note: taken from the modules here (run_true_streaming reads it from config).
    subsampling_factor = modules.subsampling_factor

    # Load full audio
    full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()
    audio_length = torch.tensor([len(full_audio)], dtype=torch.long)

    # Extract features using NeMo preprocessor
    with torch.no_grad():
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )
    # Trim padding to the true feature length.
    processed_signal = processed_signal[:, :, :processed_signal_length.max()]

    print(f"NeMo Preproc: features shape = {processed_signal.shape}")

    # Streaming loop
    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []

    feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat, chunk_idx = 0, 0, 0

    while end_feat < feat_len:
        # Context offsets clamped at the sequence boundaries.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset]
        actual_len = chunk_feat.shape[2]

        # [B, D, T] -> [B, T, D]
        chunk_t = chunk_feat.transpose(1, 2)

        # Pad to the fixed chunk size expected by the CoreML model.
        if actual_len < config['chunk_frames']:
            pad_len = config['chunk_frames'] - actual_len
            chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_t[:, :config['chunk_frames'], :]

        # Pad/trim state tensors to fixed shapes; real lengths passed separately.
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]

        current_spkcache = state.spkcache
        if curr_spk_len < config['spkcache_input_len']:
            current_spkcache = torch.nn.functional.pad(
                current_spkcache, (0, 0, 0, config['spkcache_input_len'] - curr_spk_len)
            )
        elif curr_spk_len > config['spkcache_input_len']:
            current_spkcache = current_spkcache[:, :config['spkcache_input_len'], :]

        current_fifo = state.fifo
        if curr_fifo_len < config['fifo_input_len']:
            current_fifo = torch.nn.functional.pad(
                current_fifo, (0, 0, 0, config['fifo_input_len'] - curr_fifo_len)
            )
        elif curr_fifo_len > config['fifo_input_len']:
            current_fifo = current_fifo[:, :config['fifo_input_len'], :]

        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([actual_len], dtype=np.int32),
            "spkcache": current_spkcache.numpy().astype(np.float32),
            "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
            "fifo": current_fifo.numpy().astype(np.float32),
            "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Drop padded embedding frames before the state update.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes in diarization frames (matches NeMo's rounding).
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)
        stt_feat = end_feat
        chunk_idx += 1

    if len(all_preds) > 0:
        return torch.cat(all_preds, dim=1)
    return None
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def validate(model_name, coreml_dir, audio_path):
    """
    Validate true streaming against NeMo preprocessing.

    Runs the same audio through two pipelines and compares speaker
    probabilities frame by frame:
      1. Reference: NeMo preprocessing + CoreML main model.
      2. True streaming: CoreML preprocessor + CoreML main model.

    Args:
        model_name: HuggingFace/NGC identifier of the pretrained Sortformer model.
        coreml_dir: Directory holding the exported .mlpackage bundles.
        audio_path: Path to the input WAV file.
    """
    banner = "=" * 70
    print(banner)
    print("VALIDATION: True Streaming vs NeMo Preprocessing")
    print(banner)

    # Pull down the NeMo checkpoint on CPU; it supplies preprocessing and
    # the streaming-state update logic, while inference runs in CoreML.
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Push the streaming geometry from CONFIG onto the Sortformer modules.
    modules = nemo_model.sortformer_modules
    for attr in ('chunk_len', 'chunk_right_context', 'chunk_left_context',
                 'fifo_len', 'spkcache_len', 'spkcache_update_period'):
        setattr(modules, attr, CONFIG[attr])

    # Dither injects random noise and pad_to alters frame counts — both
    # would break an exact numerical comparison, so switch them off.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, "
          f"right_ctx={modules.chunk_right_context}")

    # CPU_ONLY keeps CoreML off the ANE/GPU so results are deterministic
    # and directly comparable between the two runs.
    print(f"Loading CoreML Models from {coreml_dir}...")

    def _load_package(package_name):
        # One-line purpose: open an .mlpackage from coreml_dir pinned to CPU.
        return ct.models.MLModel(
            os.path.join(coreml_dir, package_name),
            compute_units=ct.ComputeUnit.CPU_ONLY
        )

    preproc_model = _load_package("SortformerPreprocessor.mlpackage")
    main_model = _load_package("Sortformer16.mlpackage")

    # --- Reference pass: NeMo preprocessing feeding the CoreML main model ---
    print("\n" + banner)
    print("TEST 1: NeMo Preprocessing + CoreML Inference (Reference)")
    print(banner)

    reference = run_reference(nemo_model, main_model, audio_path, CONFIG)
    if reference is None:
        print("Reference inference failed!")
        return
    reference_np = reference.squeeze(0).detach().cpu().numpy()
    print(f"Reference Probs Shape: {reference_np.shape}")

    # --- Streaming pass: audio goes through the CoreML preprocessor too ---
    print("\n" + banner)
    print("TEST 2: True Streaming (Audio → CoreML Preproc → CoreML Main)")
    print(banner)

    streaming = run_true_streaming(nemo_model, preproc_model, main_model, audio_path, CONFIG)
    if streaming is None:
        print("True streaming inference produced no output!")
        return
    streaming_np = streaming.squeeze(0).detach().cpu().numpy()
    print(f"Streaming Probs Shape: {streaming_np.shape}")

    # Compare over the shared prefix — chunking can leave the two runs with
    # slightly different frame counts at the tail.
    overlap = min(reference_np.shape[0], streaming_np.shape[0])
    abs_err = np.abs(reference_np[:overlap] - streaming_np[:overlap])
    worst = np.max(abs_err)
    print(f"\nLength: ref={reference_np.shape[0]}, streaming={streaming_np.shape[0]}")
    print(f"Mean Absolute Error: {np.mean(abs_err):.8f}")
    print(f"Max Absolute Error: {worst:.8f}")

    if worst < 0.01:
        print("\n✅ SUCCESS: True streaming matches reference!")
    else:
        print("\n⚠️ Errors exceed tolerance")


if __name__ == "__main__":
    # CLI entry point: wire up the three validation knobs and run.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--model_name", "nvidia/diar_streaming_sortformer_4spk-v2.1"),
        ("--coreml_dir", "coreml_models"),
        ("--audio_path", "audio.wav"),
    ):
        cli.add_argument(flag, default=default)
    opts = cli.parse_args()

    validate(opts.model_name, opts.coreml_dir, opts.audio_path)