alexwengg commited on
Commit
ed33fd7
·
verified ·
1 Parent(s): 4f2471f

Upload 33 files

Browse files
convert.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# convert.py — export the streaming Sortformer diarizer to ONNX and CoreML.
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib
import seaborn as sns
import numpy as np
import threading
import onnx2torch
import onnxscript
from nemo.collections.asr.models import SortformerEncLabelModel
from pydub import AudioSegment
import coremltools as ct
from pydub.playback import play as play_audio
# NOTE(review): plt, patches, sns, np, threading, onnxscript and play_audio are
# not used in this file — confirm before removing, they may be leftovers.

# --- 1. Setup & Config ---
# Prefer Apple's Metal (MPS) backend when available, otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
audio_file = "audio.wav"

# Load Audio for Playback (pydub uses milliseconds)
print("Loading audio file for playback...")
full_audio = AudioSegment.from_wav(audio_file)

# --- 2. Load Model ---
# Downloads the pretrained streaming Sortformer checkpoint from the HF hub.
model = SortformerEncLabelModel.from_pretrained(
    "nvidia/diar_streaming_sortformer_4spk-v2.1",
    map_location=device
)
model.eval()
model.to(device)

print(model.output_names)
def streaming_input_examples(self):
    """Build random example tensors for exporting the streaming model.

    ``self`` only needs a ``cfg`` mapping and a ``device`` attribute, so this
    can be invoked as a free function on the loaded model:
    ``streaming_input_examples(model)``.

    Returns:
        Tuple ``(chunk, chunk_lengths, spkcache, spkcache_lengths, fifo,
        fifo_lengths)`` sized for a batch of 4.
    """
    batch_size = 4
    # Mel-feature dimension from the preprocessor config; defaults to 128.
    feat_in = self.cfg.get("preprocessor", {}).get("features", 128)
    dev = self.device

    chunk = torch.rand([batch_size, 120, feat_in]).to(dev)
    chunk_lengths = torch.tensor([120 for _ in range(batch_size)]).to(dev)

    # Speaker cache and FIFO states with assorted valid lengths (including 0)
    # so the exporter sees both empty and partially-filled state.
    spkcache = torch.randn([batch_size, 188, 512]).to(dev)
    spkcache_lengths = torch.tensor([40, 188, 0, 68]).to(dev)
    fifo = torch.randn([batch_size, 188, 512]).to(dev)
    fifo_lengths = torch.tensor([50, 88, 0, 90]).to(dev)

    return chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths
inputs = streaming_input_examples(model)

# Export to ONNX first, then rebuild a torch graph from the ONNX file so it
# can be handed to coremltools.
export_out = model.export("streaming-sortformer.onnx", input_example=inputs)
scripted_model = onnx2torch.convert('streaming-sortformer.onnx')

# Fixed shapes for the CoreML input/output interface (match the example inputs).
BATCH_SIZE = 4
CHUNK_LEN = 120
FEAT_DIM = 128
CACHE_LEN = 188
EMBED_DIM = 512

ct_inputs = [
    ct.TensorType(name="chunk", shape=(BATCH_SIZE, CHUNK_LEN, FEAT_DIM)),
    ct.TensorType(name="chunk_lens", shape=(BATCH_SIZE,)),
    ct.TensorType(name="spkcache", shape=(BATCH_SIZE, CACHE_LEN, EMBED_DIM)),
    ct.TensorType(name="spkcache_lens", shape=(BATCH_SIZE,)),
    ct.TensorType(name="fifo", shape=(BATCH_SIZE, CACHE_LEN, EMBED_DIM)),
    ct.TensorType(name="fifo_lens", shape=(BATCH_SIZE,)),
]

ct_outputs = [
    ct.TensorType(name="preds"),
    ct.TensorType(name="new_spkcache"),
    ct.TensorType(name="new_spkcache_lens"),
    ct.TensorType(name="new_fifo"),
    ct.TensorType(name="new_fifo_lens"),
]

# NOTE(review): the result of ct.convert is neither assigned nor saved here —
# confirm that a `mlmodel = ...` / `.save(...)` is not missing.
ct.convert(
    scripted_model,
    inputs=ct_inputs,
    outputs=ct_outputs,
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT16,
)
convert_dynamic.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""
Convert Sortformer to CoreML with proper dynamic length handling.

The key issue: Original conversion traced with fixed lengths (spkcache=120, fifo=40),
but at runtime we need to handle empty state (spkcache=0, fifo=0) for first chunk.

Solution: Use scripting instead of tracing, or trace with multiple example lengths.
"""

import torch
import torch.nn as nn
import coremltools as ct
import numpy as np
import os
import sys

# Make the local NeMo checkout importable ahead of any installed package.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(SCRIPT_DIR, 'NeMo'))

from nemo.collections.asr.models import SortformerEncLabelModel

print("=" * 70)
print("CONVERTING SORTFORMER WITH DYNAMIC LENGTH SUPPORT")
print("=" * 70)

# Load model from the local .nemo checkpoint, on CPU for tracing.
model_path = os.path.join(SCRIPT_DIR, 'diar_streaming_sortformer_4spk-v2.nemo')
print(f"Loading model: {model_path}")
model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu', strict=False)
model.eval()

# Configure for low-latency streaming (chunk/cache sizes in diarizer frames).
modules = model.sortformer_modules
modules.chunk_len = 6
modules.chunk_left_context = 1
modules.chunk_right_context = 1
modules.fifo_len = 40
modules.spkcache_len = 120
modules.spkcache_update_period = 30

print(f"Config: chunk_len={modules.chunk_len}, left={modules.chunk_left_context}, right={modules.chunk_right_context}")
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

# Dimensions: feature frames per chunk = (core + contexts) * subsampling.
chunk_frames = (modules.chunk_len + modules.chunk_left_context + modules.chunk_right_context) * modules.subsampling_factor
fc_d_model = modules.fc_d_model  # 512
feat_dim = 128

print(f"Chunk frames: {chunk_frames}")
class DynamicPreEncoderWrapper(nn.Module):
    """Pre-encode a chunk and pack [spkcache | fifo | chunk] into a fixed buffer.

    Valid frames from each segment are copied back-to-back at the start of a
    zero-padded output of length ``max_total``. Lengths are read as Python ints
    via ``.item()``, so this is intended for eager execution/testing rather
    than tracing.
    """

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        self.model = model
        self.max_spkcache = max_spkcache
        self.max_fifo = max_fifo
        self.max_chunk = max_chunk
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Project the raw chunk features through the conformer pre-encoder.
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        # Valid-frame counts as plain ints; empty length tensors count as zero.
        n_spk = int(spkcache_lengths[0].item()) if spkcache_lengths.numel() > 0 else 0
        n_fifo = int(fifo_lengths[0].item()) if fifo_lengths.numel() > 0 else 0
        n_chunk = int(chunk_emb_lengths[0].item())
        n_total = n_spk + n_fifo + n_chunk

        batch, _, dim = spkcache.shape
        packed = torch.zeros(batch, self.max_total, dim, device=chunk.device, dtype=chunk.dtype)

        # Copy the three segments back-to-back; the remainder stays zero.
        if n_spk > 0:
            packed[:, :n_spk, :] = spkcache[:, :n_spk, :]
        cursor = n_spk
        if n_fifo > 0:
            packed[:, cursor:cursor + n_fifo, :] = fifo[:, :n_fifo, :]
        cursor += n_fifo
        packed[:, cursor:cursor + n_chunk, :] = chunk_embs[:, :n_chunk, :]

        total_length = torch.tensor([n_total], dtype=torch.long)

        return packed, total_length, chunk_embs, chunk_emb_lengths
+
88
+
89
class DynamicHeadWrapper(nn.Module):
    """Encoder + prediction head with the output masked to the valid length."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs, chunk_emb_lengths):
        # Conformer encoding of already pre-encoded features.
        fc_embs, fc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker predictions: [B, T, num_speakers].
        preds = self.model.forward_infer(fc_embs, fc_lengths)

        # Zero every frame past the first batch element's valid length.
        frames = preds.shape[1]
        valid = pre_encoder_lengths[0]
        keep = (torch.arange(frames, device=preds.device) < valid).float()
        preds = preds * keep.view(1, frames, 1)

        return preds, chunk_embs, chunk_emb_lengths
# Test with both empty and full state
print("\n" + "=" * 70)
print("TESTING DYNAMIC WRAPPERS")
print("=" * 70)

pre_encoder = DynamicPreEncoderWrapper(model)
head = DynamicHeadWrapper(model)
pre_encoder.eval()
head.eval()

# Test 1: Empty state (like chunk 0) — both caches report zero valid frames.
print("\nTest 1: Empty state (chunk 0)")
chunk = torch.randn(1, 56, 128)  # First chunk has fewer frames
chunk_len = torch.tensor([56], dtype=torch.long)
spkcache = torch.zeros(1, 120, 512)
spkcache_len = torch.tensor([0], dtype=torch.long)
fifo = torch.zeros(1, 40, 512)
fifo_len = torch.tensor([0], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")
sums = [f"{preds[0, i, :].sum().item():.4f}" for i in range(min(8, preds.shape[1]))]
print(f" First 8 pred frames sum: {sums}")

# Test 2: Full state — caches completely filled.
print("\nTest 2: Full state")
chunk = torch.randn(1, 64, 128)
chunk_len = torch.tensor([64], dtype=torch.long)
spkcache = torch.randn(1, 120, 512)
spkcache_len = torch.tensor([120], dtype=torch.long)
fifo = torch.randn(1, 40, 512)
fifo_len = torch.tensor([40], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")

print("\n" + "=" * 70)
print("ISSUE IDENTIFIED")
print("=" * 70)
print("""
The problem is that the current CoreML model was traced with FIXED lengths.
When lengths change at runtime, the traced operations don't adapt.

The fix requires re-tracing with proper dynamic handling OR using coremltools
flexible shapes feature.

For now, let's try a simpler approach: always pad inputs to max size and
use the length parameters only for extracting the correct output slice.
""")

# The issue is that torch.jit.trace captures specific tensor values
# We need to use torch.jit.script for truly dynamic behavior
# But many NeMo operations don't work with script
+ print("\nATTEMPTING CONVERSION WITH FLEXIBLE SHAPES...")
186
+
187
+ # Try using coremltools range shapes
188
+ try:
189
+ # Create wrapper that handles the length masking internally
190
+ class SimplePipelineWrapper(nn.Module):
191
+ def __init__(self, model):
192
+ super().__init__()
193
+ self.model = model
194
+
195
+ def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
196
+ # Pre-encode chunk
197
+ chunk_embs, chunk_emb_lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)
198
+
199
+ # Get lengths
200
+ spk_len = spkcache_lengths[0]
201
+ fifo_len = fifo_lengths[0]
202
+ chunk_len = chunk_emb_lens[0]
203
+
204
+ # Concatenate (always use fixed output size, rely on length for valid region)
205
+ # This matches what NeMo does internally
206
+ B = chunk.shape[0]
207
+ max_out = 168 # 120 + 40 + 8
208
+ D = 512
209
+
210
+ concat_embs = torch.zeros(B, max_out, D, device=chunk.device, dtype=chunk.dtype)
211
+
212
+ # Copy spkcache
213
+ for i in range(120):
214
+ if i < spk_len:
215
+ concat_embs[:, i, :] = spkcache[:, i, :]
216
+
217
+ # Copy fifo
218
+ for i in range(40):
219
+ if i < fifo_len:
220
+ concat_embs[:, 120 + i, :] = fifo[:, i, :]
221
+
222
+ # Copy chunk embeddings
223
+ for i in range(8):
224
+ if i < chunk_len:
225
+ concat_embs[:, 120 + 40 + i, :] = chunk_embs[:, i, :]
226
+
227
+ total_len = spk_len + fifo_len + chunk_len
228
+ total_lens = total_len.unsqueeze(0)
229
+
230
+ # Run through encoder
231
+ fc_embs, fc_lens = self.model.frontend_encoder(
232
+ processed_signal=concat_embs,
233
+ processed_signal_length=total_lens,
234
+ bypass_pre_encode=True,
235
+ )
236
+
237
+ # Get predictions
238
+ preds = self.model.forward_infer(fc_embs, fc_lens)
239
+
240
+ return preds, chunk_embs, chunk_emb_lens
241
+
242
+ wrapper = SimplePipelineWrapper(model)
243
+ wrapper.eval()
244
+
245
+ # Trace with empty state example
246
+ print("Tracing with empty state example...")
247
+ chunk = torch.randn(1, 64, 128)
248
+ chunk_len = torch.tensor([56], dtype=torch.long) # Actual length
249
+ spkcache = torch.zeros(1, 120, 512)
250
+ spkcache_len = torch.tensor([0], dtype=torch.long)
251
+ fifo = torch.zeros(1, 40, 512)
252
+ fifo_len = torch.tensor([0], dtype=torch.long)
253
+
254
+ with torch.no_grad():
255
+ traced = torch.jit.trace(wrapper, (chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len))
256
+
257
+ print("Converting to CoreML...")
258
+ mlmodel = ct.convert(
259
+ traced,
260
+ inputs=[
261
+ ct.TensorType(name="chunk", shape=(1, 64, 128), dtype=np.float32),
262
+ ct.TensorType(name="chunk_lengths", shape=(1,), dtype=np.int32),
263
+ ct.TensorType(name="spkcache", shape=(1, 120, 512), dtype=np.float32),
264
+ ct.TensorType(name="spkcache_lengths", shape=(1,), dtype=np.int32),
265
+ ct.TensorType(name="fifo", shape=(1, 40, 512), dtype=np.float32),
266
+ ct.TensorType(name="fifo_lengths", shape=(1,), dtype=np.int32),
267
+ ],
268
+ outputs=[
269
+ ct.TensorType(name="speaker_preds", dtype=np.float32),
270
+ ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
271
+ ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
272
+ ],
273
+ minimum_deployment_target=ct.target.iOS16,
274
+ compute_precision=ct.precision.FLOAT32,
275
+ compute_units=ct.ComputeUnit.CPU_ONLY, # Start with CPU for debugging
276
+ )
277
+
278
+ output_path = os.path.join(SCRIPT_DIR, 'coreml_models', 'SortformerPipeline_Dynamic.mlpackage')
279
+ mlmodel.save(output_path)
280
+ print(f"Saved to: {output_path}")
281
+
282
+ # Test the new model
283
+ print("\nTesting new CoreML model...")
284
+ test_output = mlmodel.predict({
285
+ 'chunk': chunk.numpy(),
286
+ 'chunk_lengths': chunk_len.numpy().astype(np.int32),
287
+ 'spkcache': spkcache.numpy(),
288
+ 'spkcache_lengths': spkcache_len.numpy().astype(np.int32),
289
+ 'fifo': fifo.numpy(),
290
+ 'fifo_lengths': fifo_len.numpy().astype(np.int32),
291
+ })
292
+
293
+ coreml_preds = np.array(test_output['speaker_preds'])
294
+ print(f"CoreML predictions shape: {coreml_preds.shape}")
295
+ print(f"CoreML first 8 frames:")
296
+ for i in range(min(8, coreml_preds.shape[1])):
297
+ print(f" Frame {i}: {coreml_preds[0, i, :]}")
298
+
299
+ except Exception as e:
300
+ print(f"Error during conversion: {e}")
301
+ import traceback
302
+ traceback.print_exc()
coreml_wrappers.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from safe_concat import *
4
+ from nemo.collections.asr.models import SortformerEncLabelModel
5
+
6
+
7
def fixed_concat_and_pad(embs, lengths, max_total_len=188+188+6):
    """ANE-safe concat-and-pad that never takes a zero-length slice.

    The three segments are concatenated at their full padded sizes, then the
    valid frames are packed to the front of the output with a single gather
    whose indices are computed arithmetically from the lengths.

    Args:
        embs: List of 3 tensors [spkcache, fifo, chunk], each (B, seq_len, D).
        lengths: List of 3 one-element (or scalar) length tensors; the first
            two may be 0, the third is always > 0.
        max_total_len: Output sequence length (zero-padded past the packed part).

    Returns:
        output: (B, max_total_len, D) with valid frames packed at the start.
        total_length: 0-dim tensor, sum of the three lengths.
    """
    batch, _, dim = embs[0].shape
    device = embs[0].device

    # Padded segment sizes are known at trace time and become graph constants.
    size0 = embs[0].shape[1]
    size1 = embs[1].shape[1]
    size2 = embs[2].shape[1]
    total_input_size = size0 + size1 + size2

    # One full-size concat — no data-dependent slicing anywhere.
    stacked = torch.cat(embs, dim=1)  # (B, total_input_size, D)

    # Valid lengths as 0-dim scalars for cheap broadcasting.
    len0, len1, len2 = (l.reshape(()) for l in lengths)
    total_length = len0 + len1 + len2

    # Output positions 0..max_total_len-1.
    positions = torch.arange(max_total_len, device=device, dtype=torch.long)

    # Each output position p reads input index p + shift, where the shift skips
    # the unused padding of every fully-consumed earlier segment:
    #   segment 0 (p < len0):              shift = 0
    #   segment 1 (len0 <= p < len0+len1): shift = size0 - len0
    #   segment 2 (p >= len0+len1):        shift = (size0 - len0) + (size1 - len1)
    past_seg0 = (positions >= len0).long()
    past_seg1 = (positions >= len0 + len1).long()
    shift = past_seg0 * (size0 - len0) + past_seg1 * (size1 - len1)
    src_idx = (positions + shift).clamp(0, total_input_size - 1)

    # Broadcast indices to (B, max_total_len, D) and gather the packed frames.
    src_idx = src_idx.view(1, max_total_len, 1).expand(batch, max_total_len, dim)
    output = torch.gather(stacked, dim=1, index=src_idx)

    # Zero everything at or past the packed length.
    output = output * (positions < total_length).float().unsqueeze(0).unsqueeze(-1)

    return output, total_length
77
+
78
+
79
class PreprocessorWrapper(nn.Module):
    """CoreML-export adapter for the NeMo preprocessor (FilterbankFeaturesTA).

    Presents the fixed ``(audio, length) -> (features, length)`` signature the
    exporter expects; the wrapped preprocessor returns features as [B, D, T].
    """

    def __init__(self, preprocessor):
        super().__init__()
        self.preprocessor = preprocessor

    def forward(self, audio_signal, length):
        # Delegate directly; the preprocessor already returns (features, length).
        features_and_length = self.preprocessor(input_signal=audio_signal, length=length)
        return features_and_length
93
+
94
+
95
class SortformerHeadWrapper(nn.Module):
    """Encoder + prediction head, passing chunk embeddings through untouched."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_pre_encoder_embs, chunk_pre_encoder_lengths):
        # Conformer encoding of the packed [spkcache|fifo|chunk] embeddings.
        enc_embs, enc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker predictions.
        preds = self.model.forward_infer(enc_embs, enc_lengths)

        return preds, chunk_pre_encoder_embs, chunk_pre_encoder_lengths
112
+
113
+
114
class SortformerCoreMLWrapper(nn.Module):
    """Full export pipeline: pre-encode & pack state, encode, then predict."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Handles chunk pre-encoding plus packing of [spkcache | fifo | chunk].
        self.pre_encoder = PreEncoderWrapper(model)

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Stage 1: pre-encode the chunk and pack it with the streaming state.
        (packed_embs, packed_lengths,
         chunk_embs, chunk_emb_lengths) = self.pre_encoder(
            chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths
        )

        # Stage 2: conformer encoding of the packed sequence.
        enc_embs, enc_lengths = self.model.frontend_encoder(
            processed_signal=packed_embs,
            processed_signal_length=packed_lengths,
            bypass_pre_encode=True,
        )

        # Stage 3: per-frame speaker predictions.
        preds = self.model.forward_infer(enc_embs, enc_lengths)
        return preds, chunk_embs, chunk_emb_lengths
143
+
144
+
145
class PreEncoderWrapper(nn.Module):
    """Chunk pre-encoding plus ANE-safe packing of [spkcache | fifo | chunk].

    Dispatches on arity: the 6-argument form also packs the streaming state via
    ``fixed_concat_and_pad``; the 2-argument form only pre-encodes the chunk.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        modules = model.sortformer_modules
        # Fixed packed length: speaker cache + FIFO + chunk incl. its contexts.
        window = modules.chunk_left_context + modules.chunk_len + modules.chunk_right_context
        self.pre_encoder_length = modules.spkcache_len + modules.fifo_len + window

    def forward(self, *args):
        # Arity dispatch keeps a single exportable entry point.
        if len(args) == 6:
            return self.forward_concat(*args)
        return self.forward_pre_encode(*args)

    def forward_concat(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)
        chunk_emb_lengths = chunk_emb_lengths.to(torch.int64)
        packed_embs, packed_lengths = fixed_concat_and_pad(
            [spkcache, fifo, chunk_embs],
            [spkcache_lengths, fifo_lengths, chunk_emb_lengths],
            self.pre_encoder_length,
        )
        return packed_embs, packed_lengths, chunk_embs, chunk_emb_lengths

    def forward_pre_encode(self, chunk, chunk_lengths):
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)
        return chunk_embs, chunk_emb_lengths.to(torch.int64)
180
+
181
+
182
class ConformerEncoderWrapper(nn.Module):
    """Standalone conformer-encoder stage (pre-encoding bypassed) for export."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encode_embs, pre_encode_lengths):
        # Encode already pre-encoded embeddings; returns (embeddings, lengths).
        encoded, encoded_lengths = self.model.frontend_encoder(
            processed_signal=pre_encode_embs,
            processed_signal_length=pre_encode_lengths,
            bypass_pre_encode=True,
        )
        return encoded, encoded_lengths
199
+
200
+
201
class SortformerEncoderWrapper(nn.Module):
    """Standalone prediction-head stage (``forward_infer``) for export."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, encoder_embs, encoder_lengths):
        # Per-frame speaker predictions from the encoded sequence.
        return self.model.forward_infer(encoder_embs, encoder_lengths)
mic_inference.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Real-Time Microphone Diarization with CoreML

This script captures audio from the microphone in real-time,
processes it through CoreML models, and displays a live updating
diarization heatmap.

Pipeline: Microphone → Audio Buffer → CoreML Preproc → CoreML Main → Live Plot

Requirements:
    pip install pyaudio matplotlib seaborn numpy coremltools

Usage:
    python mic_inference.py
"""
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
import coremltools as ct
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')
import seaborn as sns
import threading
import queue
import time
import math
import argparse

# Import NeMo for state management
from nemo.collections.asr.models import SortformerEncLabelModel

# Optional dependency: degrade gracefully when sounddevice is missing.
# BUG FIX: the except branch previously re-imported sounddevice, which would
# re-raise ImportError immediately and defeat the fallback; bind sd to None
# so SOUNDDEVICE_AVAILABLE actually reaches False.
try:
    import sounddevice as sd
    SOUNDDEVICE_AVAILABLE = True
except ImportError:
    sd = None
    SOUNDDEVICE_AVAILABLE = False
    print("Warning: sounddevice not available. Install with: pip install sounddevice")


# ============================================================
# Configuration
# ============================================================
CONFIG = {
    'chunk_len': 6,                 # diarizer frames per chunk (post-subsampling)
    'chunk_right_context': 1,
    'chunk_left_context': 1,
    'fifo_len': 40,
    'spkcache_len': 120,
    'spkcache_update_period': 32,
    'subsampling_factor': 8,
    'sample_rate': 16000,
    'mel_window': 400,              # STFT window in samples (25 ms @ 16 kHz)
    'mel_stride': 160,              # STFT hop in samples (10 ms @ 16 kHz)

    # Audio settings
    'audio_chunk_samples': 1280,    # 80ms chunks from mic
    'channels': 1,
}

# Derived sizes used throughout the pipeline.
CONFIG['spkcache_input_len'] = CONFIG['spkcache_len']
CONFIG['fifo_input_len'] = CONFIG['fifo_len']
# Feature frames per diarizer call: (core + left + right) * subsampling.
CONFIG['chunk_frames'] = (CONFIG['chunk_len'] + CONFIG['chunk_left_context'] + CONFIG['chunk_right_context']) * CONFIG['subsampling_factor']
# Audio samples needed to produce chunk_frames mel frames.
CONFIG['preproc_audio_samples'] = (CONFIG['chunk_frames'] - 1) * CONFIG['mel_stride'] + CONFIG['mel_window']
69
class MicrophoneStream:
    """Non-blocking microphone capture that pushes float32 chunks onto a queue."""

    def __init__(self, sample_rate, chunk_size, audio_queue):
        self.sample_rate = sample_rate   # capture rate in Hz
        self.chunk_size = chunk_size     # samples delivered per callback
        self.audio_queue = audio_queue   # consumer queue of 1-D float32 arrays
        self.stream = None
        self.running = False

    def _on_audio(self, indata, frames, time_info, status):
        """sounddevice callback: forward the first channel to the queue."""
        if status:
            print(f"Audio status: {status}")
        # indata is already float32 in [-1, 1]; keep only channel 0.
        self.audio_queue.put(indata[:, 0].copy())

    def start(self):
        """Open and start the input stream. Returns True on success."""
        if not SOUNDDEVICE_AVAILABLE:
            print("sounddevice not available!")
            return False

        self.stream = sd.InputStream(
            samplerate=self.sample_rate,
            channels=1,
            dtype=np.float32,
            blocksize=self.chunk_size,
            callback=self._on_audio,
        )
        self.stream.start()
        self.running = True
        print("Microphone started...")
        return True

    def stop(self):
        """Stop and close the stream (the stream is untouched if never opened)."""
        self.running = False
        if self.stream:
            self.stream.stop()
            self.stream.close()
        print("Microphone stopped.")
109
+
110
+
111
+ class StreamingDiarizer:
112
+ """Real-time streaming diarization using CoreML."""
113
+
114
    def __init__(self, nemo_model, preproc_model, main_model, config):
        """Set up buffers, counters, and streaming state.

        Args:
            nemo_model: loaded NeMo Sortformer model; only its
                ``sortformer_modules`` helper (state management) is used here.
            preproc_model: CoreML preprocessor model (audio -> mel features).
            main_model: CoreML main diarization model.
            config: CONFIG-style dict of chunk/cache sizes and audio params.
        """
        self.modules = nemo_model.sortformer_modules
        self.preproc_model = preproc_model
        self.main_model = main_model
        self.config = config

        # Audio buffer: raw mono samples not yet consumed by the preprocessor.
        self.audio_buffer = np.array([], dtype=np.float32)

        # Feature buffer: accumulated mel features ([B, D, T]); None until the
        # first preprocessor call.
        self.feature_buffer = None
        self.features_processed = 0

        # Diarization state (speaker cache / FIFO) owned by the NeMo helpers.
        self.state = self.modules.init_streaming_state(batch_size=1, device='cpu')
        self.all_probs = []  # List of [T, 4] arrays

        # Chunk tracking: how many diarizer / preprocessor chunks consumed so far.
        self.diar_chunk_idx = 0
        self.preproc_chunk_idx = 0

        # Derived params (feature-frame counts, i.e. before subsampling).
        self.subsampling = config['subsampling_factor']
        self.core_frames = config['chunk_len'] * self.subsampling
        self.left_ctx = config['chunk_left_context'] * self.subsampling
        self.right_ctx = config['chunk_right_context'] * self.subsampling

        # Audio hop: samples dropped from the audio buffer per preprocessor
        # call; overlap_frames: leading feature frames to discard on chunks
        # after the first (they duplicate the previous window).
        self.audio_hop = config['preproc_audio_samples'] - config['mel_window']
        self.overlap_frames = (config['mel_window'] - config['mel_stride']) // config['mel_stride'] + 1
144
+
145
    def add_audio(self, audio_chunk):
        """Append new mono float32 samples to the pending audio buffer."""
        self.audio_buffer = np.concatenate([self.audio_buffer, audio_chunk])
148
+
149
    def process(self):
        """
        Process available audio through preprocessor and diarizer.

        Two-stage pipeline:
          1. Drain ``self.audio_buffer`` in fixed-size windows through the
             CoreML preprocessor, appending de-overlapped mel features to
             ``self.feature_buffer``.
          2. Whenever enough feature frames exist for a diarization chunk
             plus its right context, run the CoreML Sortformer and update
             the NeMo streaming state.

        Returns new probability frames if available.
        NOTE: if several diarization chunks are processed in one call, only
        the probabilities of the MOST RECENT chunk are returned; all chunks
        are still appended to ``self.all_probs``.
        """
        new_probs = None

        # Step 1: Run preprocessor on available audio
        # The CoreML preprocessor has a fixed input size, so we only feed it
        # full windows and keep the remainder buffered for the next call.
        while len(self.audio_buffer) >= self.config['preproc_audio_samples']:
            audio_chunk = self.audio_buffer[:self.config['preproc_audio_samples']]

            preproc_inputs = {
                "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32),
                "length": np.array([self.config['preproc_audio_samples']], dtype=np.int32)
            }

            preproc_out = self.preproc_model.predict(preproc_inputs)
            feat_chunk = np.array(preproc_out["features"])
            feat_len = int(preproc_out["feature_lengths"][0])

            # Consecutive preprocessor windows overlap in time; drop the
            # leading frames that were already produced by the previous
            # window (self.overlap_frames), except for the very first chunk.
            if self.preproc_chunk_idx == 0:
                valid_feats = feat_chunk[:, :, :feat_len]
            else:
                valid_feats = feat_chunk[:, :, self.overlap_frames:feat_len]

            if self.feature_buffer is None:
                self.feature_buffer = valid_feats
            else:
                self.feature_buffer = np.concatenate([self.feature_buffer, valid_feats], axis=2)

            # Advance by the hop (less than the window size), keeping the
            # overlapped tail so mel framing stays continuous.
            self.audio_buffer = self.audio_buffer[self.audio_hop:]
            self.preproc_chunk_idx += 1

        if self.feature_buffer is None:
            return None

        # Step 2: Run diarization on available features
        total_features = self.feature_buffer.shape[2]

        while True:
            # Calculate chunk boundaries
            chunk_start = self.diar_chunk_idx * self.core_frames
            chunk_end = chunk_start + self.core_frames

            # Need right context
            required_features = chunk_end + self.right_ctx

            if required_features > total_features:
                break  # Not enough features yet

            # Extract with context (left context shrinks at stream start)
            left_offset = min(self.left_ctx, chunk_start)
            right_offset = min(self.right_ctx, total_features - chunk_end)

            feat_start = chunk_start - left_offset
            feat_end = chunk_end + right_offset

            chunk_feat = self.feature_buffer[:, :, feat_start:feat_end]
            chunk_feat_tensor = torch.from_numpy(chunk_feat).float()
            actual_len = chunk_feat.shape[2]

            # Transpose to [B, T, D]
            chunk_t = chunk_feat_tensor.transpose(1, 2)

            # Pad if needed — the CoreML model expects a fixed chunk size;
            # the real length is passed separately via chunk_lengths.
            if actual_len < self.config['chunk_frames']:
                pad_len = self.config['chunk_frames'] - actual_len
                chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len))
            else:
                chunk_in = chunk_t[:, :self.config['chunk_frames'], :]

            # State preparation: pad/trim spkcache and FIFO to the fixed
            # CoreML input sizes while tracking the true lengths.
            curr_spk_len = self.state.spkcache.shape[1]
            curr_fifo_len = self.state.fifo.shape[1]

            current_spkcache = self.state.spkcache
            if curr_spk_len < self.config['spkcache_input_len']:
                current_spkcache = torch.nn.functional.pad(
                    current_spkcache, (0, 0, 0, self.config['spkcache_input_len'] - curr_spk_len)
                )
            elif curr_spk_len > self.config['spkcache_input_len']:
                current_spkcache = current_spkcache[:, :self.config['spkcache_input_len'], :]

            current_fifo = self.state.fifo
            if curr_fifo_len < self.config['fifo_input_len']:
                current_fifo = torch.nn.functional.pad(
                    current_fifo, (0, 0, 0, self.config['fifo_input_len'] - curr_fifo_len)
                )
            elif curr_fifo_len > self.config['fifo_input_len']:
                current_fifo = current_fifo[:, :self.config['fifo_input_len'], :]

            # CoreML inference
            coreml_inputs = {
                "chunk": chunk_in.numpy().astype(np.float32),
                "chunk_lengths": np.array([actual_len], dtype=np.int32),
                "spkcache": current_spkcache.numpy().astype(np.float32),
                "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
                "fifo": current_fifo.numpy().astype(np.float32),
                "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32)
            }

            st_time = time.time_ns()
            coreml_out = self.main_model.predict(coreml_inputs)
            ed_time = time.time_ns()
            print(f"duration: {1e-6 * (ed_time - st_time)}")

            pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
            chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
            chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

            # Drop the padded embedding frames before the state update.
            chunk_embs = chunk_embs[:, :chunk_emb_len, :]

            # Convert feature-frame offsets to diarization-frame offsets
            # (round for left / ceil for right, mirroring NeMo's convention).
            lc = round(left_offset / self.subsampling)
            rc = math.ceil(right_offset / self.subsampling)

            # NeMo handles the spkcache/FIFO bookkeeping; the CoreML model
            # only produces predictions and pre-encoder embeddings.
            self.state, chunk_probs = self.modules.streaming_update(
                streaming_state=self.state,
                chunk=chunk_embs,
                preds=pred_logits,
                lc=lc,
                rc=rc
            )

            # Store probabilities
            probs_np = chunk_probs.squeeze(0).detach().cpu().numpy()
            self.all_probs.append(probs_np)

            new_probs = probs_np
            self.diar_chunk_idx += 1

        return new_probs
280
+
281
+ def get_all_probs(self):
282
+ """Get all accumulated probabilities."""
283
+ if len(self.all_probs) > 0:
284
+ return np.concatenate(self.all_probs, axis=0)
285
+ return None
286
+
287
+
288
def run_mic_inference(model_name, coreml_dir):
    """Run real-time microphone diarization.

    Captures audio from the default microphone, streams it through the
    StreamingDiarizer (CoreML preprocessor + CoreML Sortformer, with NeMo
    handling the streaming-state updates) and redraws a live speaker
    probability heatmap roughly every 160 ms until Ctrl+C.

    Args:
        model_name: NeMo pretrained model id or path (state logic only).
        coreml_dir: Directory containing the exported .mlpackage bundles.
    """

    if not SOUNDDEVICE_AVAILABLE:
        print("Cannot run mic inference without sounddevice!")
        return

    print("=" * 70)
    print("Real-Time Microphone Diarization")
    print("=" * 70)

    # Load NeMo model (used for streaming-state bookkeeping, not inference)
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Configure streaming parameters to match the CoreML export (CONFIG)
    modules = nemo_model.sortformer_modules
    modules.chunk_len = CONFIG['chunk_len']
    modules.chunk_right_context = CONFIG['chunk_right_context']
    modules.chunk_left_context = CONFIG['chunk_left_context']
    modules.fifo_len = CONFIG['fifo_len']
    modules.spkcache_len = CONFIG['spkcache_len']
    modules.spkcache_update_period = CONFIG['spkcache_update_period']

    # Disable dither/padding so feature extraction is deterministic
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    # Load CoreML models. Preprocessor runs CPU-only; the main model may
    # use ANE/GPU (ComputeUnit.ALL).
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Pipeline_Preprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPipeline.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    # Create diarizer
    diarizer = StreamingDiarizer(nemo_model, preproc_model, main_model, CONFIG)

    # Audio queue (filled by the capture callback, drained by this loop)
    audio_queue = queue.Queue()

    # Start microphone
    mic = MicrophoneStream(
        sample_rate=CONFIG['sample_rate'],
        chunk_size=CONFIG['audio_chunk_samples'],
        audio_queue=audio_queue
    )

    if not mic.start():
        return

    # Setup plot (interactive mode so the window updates in-place)
    plt.ion()
    fig, ax = plt.subplots(figsize=(14, 4))

    print("\nListening... Press Ctrl+C to stop.\n")

    try:
        last_update = time.time()

        while True:
            # Get audio from queue
            while not audio_queue.empty():
                audio_chunk = audio_queue.get()
                diarizer.add_audio(audio_chunk)

            # Process
            new_probs = diarizer.process()

            # Update plot periodically
            if time.time() - last_update > 0.16:  # Update every 160ms
                all_probs = diarizer.get_all_probs()

                if all_probs is not None and len(all_probs) > 0:
                    ax.clear()

                    # Show last 200 frames (~16 seconds)
                    display_frames = min(200, len(all_probs))
                    display_probs = all_probs[-display_frames:]

                    sns.heatmap(
                        display_probs.T,
                        ax=ax,
                        cmap="viridis",
                        vmin=0, vmax=1,
                        yticklabels=[f"Spk {i}" for i in range(4)],
                        cbar=False
                    )

                    ax.set_xlabel("Time (frames, 80ms each)")
                    ax.set_ylabel("Speaker")
                    ax.set_title(f"Live Diarization - Total: {len(all_probs)} frames ({len(all_probs)*0.08:.1f}s)")

                    plt.draw()
                    plt.pause(0.01)

                last_update = time.time()

            time.sleep(0.01)

    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        # Always release the mic and close the plot, even on error
        mic.stop()
        plt.ioff()
        plt.close()

        # Final summary
        all_probs = diarizer.get_all_probs()
        if all_probs is not None:
            print(f"\nTotal processed: {len(all_probs)} frames ({len(all_probs)*0.08:.1f} seconds)")
404
+
405
+
406
def run_file_demo(model_name, coreml_dir, audio_path):
    """Run demo on audio file with live updating plot.

    Simulates streaming by feeding the file to the StreamingDiarizer in
    microphone-sized chunks and redrawing the speaker-probability heatmap
    after every chunk.

    Args:
        model_name: NeMo pretrained model id or path (state logic only).
        coreml_dir: Directory containing the exported .mlpackage bundles.
        audio_path: Path to the input audio file.
    """

    print("=" * 70)
    print("File Demo with Live Updating Plot")
    print("=" * 70)

    # Load NeMo model (used for streaming-state bookkeeping, not inference)
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Configure streaming parameters to match the CoreML export (CONFIG)
    modules = nemo_model.sortformer_modules
    modules.chunk_len = CONFIG['chunk_len']
    modules.chunk_right_context = CONFIG['chunk_right_context']
    modules.chunk_left_context = CONFIG['chunk_left_context']
    modules.fifo_len = CONFIG['fifo_len']
    modules.spkcache_len = CONFIG['spkcache_len']
    modules.spkcache_update_period = CONFIG['spkcache_update_period']

    # Disable dither/padding so feature extraction is deterministic
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    # Load CoreML models
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Pipeline_Preprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPipeline.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    # Load audio file
    import librosa
    audio, _ = librosa.load(audio_path, sr=CONFIG['sample_rate'], mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/CONFIG['sample_rate']:.1f}s)")

    # Create diarizer
    diarizer = StreamingDiarizer(nemo_model, preproc_model, main_model, CONFIG)

    # Setup plot (interactive mode for in-place redraws)
    plt.ion()
    fig, ax = plt.subplots(figsize=(14, 4))

    # Simulate streaming: feed mic-sized chunks one at a time
    chunk_size = CONFIG['audio_chunk_samples']
    offset = 0

    print("\nStreaming audio with live plot...")

    try:
        while offset < len(audio):
            # Add audio chunk
            chunk_end = min(offset + chunk_size, len(audio))
            audio_chunk = audio[offset:chunk_end]
            diarizer.add_audio(audio_chunk)
            offset = chunk_end

            # Process
            diarizer.process()

            # Update plot
            all_probs = diarizer.get_all_probs()

            if all_probs is not None and len(all_probs) > 0:
                ax.clear()

                sns.heatmap(
                    all_probs.T,
                    ax=ax,
                    cmap="viridis",
                    vmin=0, vmax=1,
                    yticklabels=[f"Spk {i}" for i in range(4)],
                    cbar=False
                )

                ax.set_xlabel("Time (frames, 80ms each)")
                ax.set_ylabel("Speaker")
                ax.set_title(f"Streaming Diarization - {len(all_probs)} frames")

                plt.draw()
                plt.pause(0.05)

            # Simulate real-time (optional - comment out for fast mode)
            # time.sleep(chunk_size / CONFIG['sample_rate'])

    except KeyboardInterrupt:
        print("\nStopped.")

    plt.ioff()

    # Final plot (blocking show so the last heatmap stays visible)
    all_probs = diarizer.get_all_probs()
    if all_probs is not None:
        print(f"\nTotal: {len(all_probs)} frames ({len(all_probs)*0.08:.1f}s)")
        plt.show()
506
+
507
+
508
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="nvidia/diar_streaming_sortformer_4spk-v2.1")
    parser.add_argument("--coreml_dir", default="coreml_models")
    parser.add_argument("--audio_path", default="audio.wav")
    parser.add_argument("--mic", action="store_true", help="Use microphone input")
    args = parser.parse_args()

    # Dispatch on --mic. Previously run_mic_inference() was called
    # unconditionally and the flag (plus --audio_path) was silently ignored.
    if args.mic:
        run_mic_inference(args.model_name, args.coreml_dir)
    else:
        run_file_demo(args.model_name, args.coreml_dir, args.audio_path)
+ # run_file_demo(args.model_name, args.coreml_dir, args.audio_path)
nemo_streaming_reference.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Get exact NeMo streaming inference output for comparison with Swift."""
3
+
4
+ import os
5
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
6
+
7
+ import torch
8
+ import numpy as np
9
+ import librosa
10
+ import json
11
+
12
+ from nemo.collections.asr.models import SortformerEncLabelModel
13
+
14
def main():
    """Run exact NeMo streaming diarization on ../audio.wav as a reference.

    Uses the model's own preprocessor and forward_streaming_step so the
    output is the ground truth to compare a Swift/CoreML port against.
    Prints a per-frame timeline and saves the raw probabilities to
    /tmp/nemo_streaming_reference.json.
    """
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        'diar_streaming_sortformer_4spk-v2.nemo', map_location='cpu'
    )
    model.eval()

    # Disable dither for deterministic output
    if hasattr(model.preprocessor, 'featurizer'):
        if hasattr(model.preprocessor.featurizer, 'dither'):
            model.preprocessor.featurizer.dither = 0.0

    # Configure streaming parameters (must match the Swift side exactly)
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

    # Load audio
    audio_path = "../audio.wav"
    audio, sr = librosa.load(audio_path, sr=16000, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/16000:.2f}s)")

    waveform = torch.from_numpy(audio).unsqueeze(0).float()

    # Get mel features using model's preprocessor
    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform, audio_signal_length=audio_len
        )

    # features is [batch, mel, time], need [batch, time, mel] for streaming
    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")

    # Streaming inference using forward_streaming_step
    subsampling = modules.subsampling_factor  # 8
    chunk_len = modules.chunk_len  # 6
    left_context = modules.chunk_left_context  # 1
    right_context = modules.chunk_right_context  # 7
    core_frames = chunk_len * subsampling  # 48 mel frames

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    # Initialize streaming state
    streaming_state = modules.init_streaming_state(device=features.device)

    # Initialize total_preds tensor (accumulates [1, T, 4] predictions)
    total_preds = torch.zeros((1, 0, 4), device=features.device)

    all_preds = []
    chunk_idx = 0

    # Process chunks like streaming_feat_loader: fixed core windows with
    # left/right context clipped at the stream boundaries.
    stt_feat = 0
    while stt_feat < total_mel_frames:
        end_feat = min(stt_feat + core_frames, total_mel_frames)

        # Calculate context (in mel frames)
        left_offset = min(left_context * subsampling, stt_feat)
        right_offset = min(right_context * subsampling, total_mel_frames - end_feat)

        chunk_start = stt_feat - left_offset
        chunk_end = end_feat + right_offset

        # Extract chunk - [batch, mel, time] -> [batch, time, mel]
        chunk = features[:, :, chunk_start:chunk_end]  # [1, 128, T]
        chunk_t = chunk.transpose(1, 2)  # [1, T, 128]
        chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long)

        with torch.no_grad():
            # Use forward_streaming_step (updates state + appends preds)
            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

        chunk_idx += 1
        stt_feat = end_feat

    # total_preds now contains all predictions
    all_preds = total_preds[0].numpy()  # [total_frames, 4]
    print(f"\nTotal output frames: {all_preds.shape[0]}")
    print(f"Predictions shape: {all_preds.shape}")

    # Print timeline
    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame Time Spk0 Spk1 Spk2 Spk3 | Visual")
    print("-" * 60)

    for frame in range(all_preds.shape[0]):
        time_sec = frame * 0.08
        probs = all_preds[frame]
        visual = ['■' if p > 0.55 else '·' for p in probs]
        print(f"{frame:5d} {time_sec:5.2f}s {probs[0]:.3f} {probs[1]:.3f} {probs[2]:.3f} {probs[3]:.3f} | [{visual[0]}{visual[1]}{visual[2]}{visual[3]}]")

    print("-" * 60)

    # Speaker activity summary
    print("\n=== Speaker Activity Summary ===")
    threshold = 0.55
    for spk in range(4):
        active_frames = np.sum(all_preds[:, spk] > threshold)
        active_time = active_frames * 0.08
        percent = active_time / (all_preds.shape[0] * 0.08) * 100
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")

    # Save to JSON for comparison with the Swift implementation
    output = {
        "total_frames": int(all_preds.shape[0]),
        "frame_duration_seconds": 0.08,
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        }
    }

    with open("/tmp/nemo_streaming_reference.json", "w") as f:
        json.dump(output, f, indent=2)
    print("\nSaved to /tmp/nemo_streaming_reference.json")
151
+
152
# Script entry point.
if __name__ == "__main__":
    main()
streaming_inference.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import coremltools as ct
4
+ import librosa
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import math
9
+
10
+ # Import NeMo components for State Logic
11
+ try:
12
+ from nemo.collections.asr.models import SortformerEncLabelModel
13
+ # Try importing SortformerModules directly for type hints if needed, but we can access via model instance
14
+ from nemo.collections.asr.modules.sortformer_modules import SortformerModules
15
+ except ImportError as e:
16
+ print(f"Error importing NeMo: {e}")
17
+ sys.exit(1)
18
+
19
+
20
def streaming_feat_loader(modules, feat_seq, feat_seq_length, feat_seq_offset):
    """
    Yield successive feature chunks, with left/right context, for streaming.
    Adapted from NeMo's SortformerModules.streaming_feat_loader.

    Args:
        modules: SortformerModules-like object exposing chunk_len,
            subsampling_factor and (optionally) chunk_left_context /
            chunk_right_context.
        feat_seq (torch.Tensor): Feature sequence,
            shape (batch_size, feat_dim, feat frame count).
        feat_seq_length (torch.Tensor): Valid lengths, shape (batch_size,).
        feat_seq_offset (torch.Tensor): Frame offsets, shape (batch_size,).

    Yields:
        chunk_idx (int): Index of the current chunk.
        chunk_feat_seq_t (torch.Tensor): Chunk of the feature sequence,
            shape (batch_size, feat frame count, feat_dim) — transposed.
        feat_lengths (torch.Tensor): Valid lengths within the chunk,
            shape (batch_size,).
        left_offset (int): Left context size in feature frames.
        right_offset (int): Right context size in feature frames.
    """
    total_frames = feat_seq.shape[2]
    core = modules.chunk_len * modules.subsampling_factor
    left_ctx_frames = getattr(modules, 'chunk_left_context', 0) * modules.subsampling_factor
    right_ctx_frames = getattr(modules, 'chunk_right_context', 0) * modules.subsampling_factor

    num_chunks = math.ceil(total_frames / core)
    print(f"streaming_feat_loader: feat_len={total_frames}, num_chunks={num_chunks}, "
          f"chunk_len={modules.chunk_len}, subsampling_factor={modules.subsampling_factor}")

    start = 0
    chunk_idx = 0
    while start < total_frames:
        # Core window [start, end) plus context, clipped at stream edges.
        end = min(start + core, total_frames)
        left_offset = min(left_ctx_frames, start)
        right_offset = min(right_ctx_frames, total_frames - end)

        window = feat_seq[:, :, start - left_offset : end + right_offset]
        feat_lengths = (feat_seq_length + feat_seq_offset - start + left_offset).clamp(
            0, window.shape[2]
        )
        # Zero out batch entries whose stream has not reached this chunk yet.
        feat_lengths = feat_lengths * (feat_seq_offset < end)

        # (batch, feat_dim, frames) -> (batch, frames, feat_dim)
        chunk_feat_seq_t = torch.transpose(window, 1, 2)

        print(f"  chunk_idx: {chunk_idx}, chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
              f"feat_lengths: {feat_lengths}, left_offset: {left_offset}, right_offset: {right_offset}")

        yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
        chunk_idx += 1
        start = end
75
+
76
+
77
def run_streaming_inference(model_name, coreml_dir, audio_path):
    """Run chunked CoreML diarization over a file, using NeMo for state logic.

    Features for the whole file are extracted once with the NeMo
    preprocessor, then fed chunk-by-chunk (via streaming_feat_loader) to
    the exported CoreML Sortformer; spkcache/FIFO updates are delegated to
    NeMo's streaming_update.

    Args:
        model_name: Path to a .nemo checkpoint or a pretrained model id.
        coreml_dir: Directory with Sortformer(.mlpackage) exports.
        audio_path: Input audio file.

    Returns:
        torch.Tensor of shape [1, TotalFrames, Spks] with per-frame
        speaker probabilities, or None if no chunks were processed.
    """
    print(f"Loading NeMo Model (for Python Streaming Logic): {model_name}")
    if os.path.exists(model_name):
        nemo_model = SortformerEncLabelModel.restore_from(model_name, map_location="cpu")
    else:
        nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()
    modules = nemo_model.sortformer_modules

    # --- Override Config to match CoreML Export (Low Latency) ---
    print("Overriding Config (Inference) to match CoreML...")
    modules.chunk_len = 4
    modules.chunk_right_context = 1  # 1 chunk of right context
    modules.chunk_left_context = 2  # 2 chunks of left context
    # Match CoreML export sizes (from model spec)
    modules.fifo_len = 63
    modules.spkcache_len = 63
    modules.spkcache_update_period = 50  # Match CoreML export

    # CoreML fixed input sizes (must match export settings)
    # With left_context=2, right_context=1: (4+2+1)*8 = 56 frames
    COREML_CHUNK_FRAMES = 56
    COREML_SPKCACHE_LEN = 63
    COREML_FIFO_LEN = 63

    # Disable dither and pad_to (as diarize does) for determinism
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        if hasattr(nemo_model.preprocessor.featurizer, 'dither'):
            nemo_model.preprocessor.featurizer.dither = 0.0
        if hasattr(nemo_model.preprocessor.featurizer, 'pad_to'):
            nemo_model.preprocessor.featurizer.pad_to = 0

    # CoreML Models - use CPU_ONLY for the preprocessor for compatibility
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPreprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Sortformer.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    # Config
    chunk_len = modules.chunk_len  # Output frames (e.g., 4 for low latency)
    subsampling_factor = modules.subsampling_factor  # 8
    sample_rate = 16000

    print(f"Chunk Config: {chunk_len} output frames (diar), subsampling_factor={subsampling_factor}")

    # Load Audio
    print(f"Loading Audio: {audio_path}")
    full_audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    total_samples = len(full_audio)
    print(f"Total Samples: {total_samples} ({total_samples/sample_rate:.2f}s)")

    # === Step 1: Extract features for the ENTIRE audio using preprocessor ===
    # This matches NeMo's approach: process_signal -> forward_streaming
    print("Extracting features for entire audio...")
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()  # [1, samples]
    audio_length = torch.tensor([total_samples], dtype=torch.long)

    with torch.no_grad():
        # Use process_signal for proper normalization (same as forward())
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )

    print(f"Processed signal shape: {processed_signal.shape}")  # [1, 128, T]
    print(f"Processed signal length: {processed_signal_length}")

    # Trim to actual length
    processed_signal = processed_signal[:, :, :processed_signal_length.max()]

    # === Step 2: Initialize streaming state ===
    print("Initializing Streaming State...")
    state = modules.init_streaming_state(batch_size=1, device='cpu')

    # === Step 3: Use streaming_feat_loader to chunk features (matches NeMo exactly) ===
    batch_size = processed_signal.shape[0]
    processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long)

    all_preds = []

    feat_loader = streaming_feat_loader(
        modules=modules,
        feat_seq=processed_signal,
        feat_seq_length=processed_signal_length,
        feat_seq_offset=processed_signal_offset,
    )

    for chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in feat_loader:
        # Prepare inputs for CoreML model
        # Pad chunk to fixed size for CoreML (real length goes in chunk_lengths)
        chunk_actual_len = chunk_feat_seq_t.shape[1]
        if chunk_actual_len < COREML_CHUNK_FRAMES:
            pad_len = COREML_CHUNK_FRAMES - chunk_actual_len
            chunk_in = torch.nn.functional.pad(chunk_feat_seq_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_feat_seq_t[:, :COREML_CHUNK_FRAMES, :]
        chunk_len_in = feat_lengths.long()  # actual length

        # Get actual lengths from state (pad tensors but track real lengths)
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]
        # Prepare SpkCache - Pad to CoreML fixed size
        current_spkcache = state.spkcache

        if curr_spk_len < COREML_SPKCACHE_LEN:
            pad_len = COREML_SPKCACHE_LEN - curr_spk_len
            current_spkcache = torch.nn.functional.pad(current_spkcache, (0, 0, 0, pad_len))
        elif curr_spk_len > COREML_SPKCACHE_LEN:
            current_spkcache = current_spkcache[:, :COREML_SPKCACHE_LEN, :]

        spkcache_in = current_spkcache
        # Use actual length, not padded length
        spkcache_len_in = torch.tensor([curr_spk_len], dtype=torch.long)

        # Prepare FIFO - Pad to CoreML fixed size
        current_fifo = state.fifo

        if curr_fifo_len < COREML_FIFO_LEN:
            pad_len = COREML_FIFO_LEN - curr_fifo_len
            current_fifo = torch.nn.functional.pad(current_fifo, (0, 0, 0, pad_len))
        elif curr_fifo_len > COREML_FIFO_LEN:
            current_fifo = current_fifo[:, :COREML_FIFO_LEN, :]

        fifo_in = current_fifo
        fifo_len_in = torch.tensor([curr_fifo_len], dtype=torch.long)

        # === Run CoreML Model ===
        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": chunk_len_in.numpy().astype(np.int32),
            "spkcache": spkcache_in.numpy().astype(np.float32),
            "spkcache_lengths": spkcache_len_in.numpy().astype(np.int32),
            "fifo": fifo_in.numpy().astype(np.float32),
            "fifo_lengths": fifo_len_in.numpy().astype(np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        # Convert outputs back to torch tensors
        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Trim chunk_embs to actual length (drop padded frames)
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Compute lc and rc for streaming_update (in embeddings/diar frames, not feature frames)
        # NeMo does: lc = round(left_offset / encoder.subsampling_factor)
        #            rc = math.ceil(right_offset / encoder.subsampling_factor)
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        # Update state using streaming_update with proper lc/rc
        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        # chunk_probs is the prediction for the current chunk
        all_preds.append(chunk_probs)

        print(f"Processed chunk {chunk_idx + 1}, chunk_probs shape: {chunk_probs.shape}", end='\r')

    print(f"\nFinished. Total Chunks: {len(all_preds)}")
    if len(all_preds) > 0:
        final_probs = torch.cat(all_preds, dim=1)  # [1, TotalFrames, Spks]
        print(f"Final Predictions Shape: {final_probs.shape}")
        return final_probs
    return None
253
+
254
+
255
# CLI entry point: run CoreML streaming inference over one audio file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="nvidia/diar_streaming_sortformer_4spk-v2.1")
    parser.add_argument("--coreml_dir", default="coreml_models")
    parser.add_argument("--audio_path", default="test2.wav")
    args = parser.parse_args()

    run_streaming_inference(args.model_name, args.coreml_dir, args.audio_path)
streaming_preproc_inference.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ True Streaming CoreML Diarization
3
+
4
+ This script implements true streaming inference:
5
+ Audio chunks → CoreML Preprocessor → Feature Buffer → CoreML Main Model → Predictions
6
+
7
+ Audio is processed incrementally, features are accumulated with proper context handling.
8
+ """
9
+ import os
10
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
11
+
12
+ import torch
13
+ import numpy as np
14
+ import coremltools as ct
15
+ import librosa
16
+ import argparse
17
+ import math
18
+
19
+ # Import NeMo for state management (streaming_update) only
20
+ from nemo.collections.asr.models import SortformerEncLabelModel
21
+
22
+
23
+ # ============================================================
24
+ # Configuration for Sortformer16.mlpackage
25
+ # ============================================================
26
+ CONFIG = {
27
+ 'chunk_len': 4, # Diarization chunk length
28
+ 'chunk_right_context': 1, # Right context chunks
29
+ 'chunk_left_context': 2, # Left context chunks
30
+ 'fifo_len': 63,
31
+ 'spkcache_len': 63,
32
+ 'spkcache_update_period': 50,
33
+ 'subsampling_factor': 8,
34
+ 'sample_rate': 16000,
35
+
36
+ # Derived values
37
+ 'chunk_frames': 56, # (4+2+1)*8 = 56 feature frames for CoreML input
38
+ 'spkcache_input_len': 63,
39
+ 'fifo_input_len': 63,
40
+
41
+ # Preprocessor settings
42
+ 'preproc_audio_samples': 9200, # CoreML preprocessor fixed input size
43
+ 'mel_window': 400, # 25ms @ 16kHz
44
+ 'mel_stride': 160, # 10ms @ 16kHz
45
+ }
46
+
47
+
48
def run_true_streaming(nemo_model, preproc_model, main_model, audio_path: str, config: dict):
    """
    True streaming inference: audio chunks → CoreML preprocessor → CoreML main model.

    Strategy:
        1. Process audio in fixed-size chunks through the CoreML preprocessor.
        2. Accumulate mel features, dropping the overlap frames between chunks.
        3. Slice the feature stream into diarization chunks (with left/right
           context) and run the CoreML main model, threading NeMo's streaming
           state (spkcache/fifo) between steps.

    Args:
        nemo_model: loaded SortformerEncLabelModel; only its `sortformer_modules`
            (chunk geometry + streaming-state helpers) are used here.
        preproc_model: CoreML mel preprocessor with inputs "audio_signal"/"length"
            and outputs "features" [1, 128, frames] / "feature_lengths".
        main_model: CoreML Sortformer with chunk/spkcache/fifo inputs and
            "speaker_preds" / "chunk_pre_encoder_embs" / "chunk_pre_encoder_lengths" outputs.
        audio_path: path to the wav file to diarize.
        config: the module-level CONFIG dict.

    Returns:
        torch.Tensor of per-frame speaker probabilities (concatenated along
        dim 1 over all chunks), or None if the audio yielded no chunks.
    """
    modules = nemo_model.sortformer_modules
    subsampling_factor = config['subsampling_factor']

    # Load full audio (simulating microphone input); librosa resamples to
    # config['sample_rate'] and downmixes to mono.
    full_audio, sr = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    total_samples = len(full_audio)

    print(f"Total audio samples: {total_samples}")

    # Preprocessing parameters
    mel_window = config['mel_window']
    mel_stride = config['mel_stride']
    preproc_len = config['preproc_audio_samples']

    # Audio hop between preprocessor calls: advance by one window less than the
    # input size so consecutive calls overlap by `mel_window` samples
    # (9200 - 400 = 8800 samples).
    audio_hop = preproc_len - mel_window

    # Feature accumulator
    all_features = []
    audio_offset = 0
    preproc_chunk_idx = 0

    # Step 1: Process all audio through preprocessor to get features
    print("Step 1: Extracting features via CoreML preprocessor...")
    while audio_offset < total_samples:
        # Get audio chunk; the last one may be short.
        chunk_end = min(audio_offset + preproc_len, total_samples)
        audio_chunk = full_audio[audio_offset:chunk_end]
        actual_samples = len(audio_chunk)

        # Zero-pad the tail chunk to the fixed CoreML input size; the true
        # length is still passed via "length" below.
        if actual_samples < preproc_len:
            audio_chunk = np.pad(audio_chunk, (0, preproc_len - actual_samples))

        # Run preprocessor
        preproc_inputs = {
            "audio_signal": audio_chunk.reshape(1, -1).astype(np.float32),
            "length": np.array([actual_samples], dtype=np.int32)
        }

        preproc_out = preproc_model.predict(preproc_inputs)
        feat_chunk = np.array(preproc_out["features"])  # [1, 128, frames]
        feat_len = int(preproc_out["feature_lengths"][0])

        # Extract valid features and drop the frames duplicated by the
        # `mel_window` sample overlap between consecutive preprocessor calls.
        if preproc_chunk_idx == 0:
            # First chunk: keep all valid frames.
            valid_feats = feat_chunk[:, :, :feat_len]
        else:
            # NOTE(review): overlap_frames = (400-160)//160 + 1 = 2 here; this
            # is a framing heuristic — confirm it matches the preprocessor's
            # exact STFT centering/padding so no frame is dropped or doubled.
            overlap_frames = (mel_window - mel_stride) // mel_stride + 1  # ~2-3 frames
            valid_feats = feat_chunk[:, :, overlap_frames:feat_len]

        all_features.append(valid_feats)

        audio_offset += audio_hop
        preproc_chunk_idx += 1

        print(f"\r Processed audio chunk {preproc_chunk_idx}, features so far: {sum(f.shape[2] for f in all_features)}", end='')

    print()

    # Concatenate all features into one stream for the diarization loop.
    full_features = np.concatenate(all_features, axis=2)  # [1, 128, total_frames]
    processed_signal = torch.from_numpy(full_features).float()
    # NOTE(review): processed_signal_length is computed but never used below.
    processed_signal_length = torch.tensor([full_features.shape[2]], dtype=torch.long)

    print(f"Total features extracted: {processed_signal.shape}")

    # Step 2: Run diarization streaming loop (same structure as run_reference).
    print("Step 2: Running diarization streaming...")

    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []

    feat_len = processed_signal.shape[2]  # reuses the name: now TOTAL feature frames
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat, chunk_idx = 0, 0, 0

    while end_feat < feat_len:
        # Clamp context at the stream boundaries.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        # Extract chunk with context
        chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset]
        actual_len = chunk_feat.shape[2]

        # Transpose [B, D, T] -> [B, T, D] for the CoreML chunk input.
        chunk_t = chunk_feat.transpose(1, 2)

        # Pad (or truncate) to the fixed CoreML chunk size.
        # NOTE(review): in the truncation branch, "chunk_lengths" below still
        # carries actual_len > chunk_frames — confirm the model clamps this.
        if actual_len < config['chunk_frames']:
            pad_len = config['chunk_frames'] - actual_len
            chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_t[:, :config['chunk_frames'], :]

        # State preparation: pad/truncate spkcache and fifo to the fixed
        # lengths the CoreML model expects, but pass their true lengths.
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]

        current_spkcache = state.spkcache
        if curr_spk_len < config['spkcache_input_len']:
            current_spkcache = torch.nn.functional.pad(
                current_spkcache, (0, 0, 0, config['spkcache_input_len'] - curr_spk_len)
            )
        elif curr_spk_len > config['spkcache_input_len']:
            current_spkcache = current_spkcache[:, :config['spkcache_input_len'], :]

        current_fifo = state.fifo
        if curr_fifo_len < config['fifo_input_len']:
            current_fifo = torch.nn.functional.pad(
                current_fifo, (0, 0, 0, config['fifo_input_len'] - curr_fifo_len)
            )
        elif curr_fifo_len > config['fifo_input_len']:
            current_fifo = current_fifo[:, :config['fifo_input_len'], :]

        # CoreML inference
        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([actual_len], dtype=np.int32),
            "spkcache": current_spkcache.numpy().astype(np.float32),
            "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
            "fifo": current_fifo.numpy().astype(np.float32),
            "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Trim the embeddings to their valid length before the state update.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes in encoder frames for NeMo's state update.
        # NOTE(review): left uses round(), right uses ceil() — confirm this
        # asymmetry is intentional (it mirrors run_reference).
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        # NeMo updates spkcache/fifo and returns this chunk's probabilities.
        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)
        stt_feat = end_feat
        chunk_idx += 1

        print(f"\r Diarization chunk {chunk_idx}", end='')

    print()

    if len(all_preds) > 0:
        return torch.cat(all_preds, dim=1)
    return None
218
+
219
+
220
def run_reference(nemo_model, main_model, audio_path: str, config: dict):
    """
    Reference implementation: NeMo preprocessing + CoreML main-model inference.

    Identical diarization loop to run_true_streaming, but the mel features
    come from NeMo's own preprocessor (via nemo_model.process_signal) instead
    of the CoreML preprocessor — this isolates preprocessing differences when
    comparing the two paths.

    Args:
        nemo_model: loaded SortformerEncLabelModel (preprocessor + streaming helpers).
        main_model: CoreML Sortformer main model.
        audio_path: path to the wav file to diarize.
        config: the module-level CONFIG dict.

    Returns:
        torch.Tensor of per-frame speaker probabilities, or None if no chunks.
    """
    modules = nemo_model.sortformer_modules
    # Unlike run_true_streaming, this reads the factor from the model itself.
    subsampling_factor = modules.subsampling_factor

    # Load full audio
    full_audio, _ = librosa.load(audio_path, sr=config['sample_rate'], mono=True)
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()
    audio_length = torch.tensor([len(full_audio)], dtype=torch.long)

    # Extract features using NeMo preprocessor; trim padding to the valid length.
    with torch.no_grad():
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )
        processed_signal = processed_signal[:, :, :processed_signal_length.max()]

    print(f"NeMo Preproc: features shape = {processed_signal.shape}")

    # Streaming loop (mirrors run_true_streaming step 2 — keep the two in sync).
    state = modules.init_streaming_state(batch_size=1, device='cpu')
    all_preds = []

    feat_len = processed_signal.shape[2]
    chunk_len = modules.chunk_len
    left_ctx = modules.chunk_left_context
    right_ctx = modules.chunk_right_context

    stt_feat, end_feat, chunk_idx = 0, 0, 0

    while end_feat < feat_len:
        # Clamp context at the stream boundaries.
        left_offset = min(left_ctx * subsampling_factor, stt_feat)
        end_feat = min(stt_feat + chunk_len * subsampling_factor, feat_len)
        right_offset = min(right_ctx * subsampling_factor, feat_len - end_feat)

        chunk_feat = processed_signal[:, :, stt_feat - left_offset : end_feat + right_offset]
        actual_len = chunk_feat.shape[2]

        # [B, D, T] -> [B, T, D]
        chunk_t = chunk_feat.transpose(1, 2)

        # Pad/truncate to the fixed CoreML chunk size.
        if actual_len < config['chunk_frames']:
            pad_len = config['chunk_frames'] - actual_len
            chunk_in = torch.nn.functional.pad(chunk_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_t[:, :config['chunk_frames'], :]

        # Pad/truncate streaming state to fixed CoreML input lengths,
        # passing the true lengths alongside.
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]

        current_spkcache = state.spkcache
        if curr_spk_len < config['spkcache_input_len']:
            current_spkcache = torch.nn.functional.pad(
                current_spkcache, (0, 0, 0, config['spkcache_input_len'] - curr_spk_len)
            )
        elif curr_spk_len > config['spkcache_input_len']:
            current_spkcache = current_spkcache[:, :config['spkcache_input_len'], :]

        current_fifo = state.fifo
        if curr_fifo_len < config['fifo_input_len']:
            current_fifo = torch.nn.functional.pad(
                current_fifo, (0, 0, 0, config['fifo_input_len'] - curr_fifo_len)
            )
        elif curr_fifo_len > config['fifo_input_len']:
            current_fifo = current_fifo[:, :config['fifo_input_len'], :]

        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": np.array([actual_len], dtype=np.int32),
            "spkcache": current_spkcache.numpy().astype(np.float32),
            "spkcache_lengths": np.array([curr_spk_len], dtype=np.int32),
            "fifo": current_fifo.numpy().astype(np.float32),
            "fifo_lengths": np.array([curr_fifo_len], dtype=np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Trim embeddings to their valid length.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Context sizes in encoder frames for the NeMo state update.
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)
        stt_feat = end_feat
        chunk_idx += 1

    if len(all_preds) > 0:
        return torch.cat(all_preds, dim=1)
    return None
322
+
323
+
324
def validate(model_name: str, coreml_dir: str, audio_path: str):
    """
    Validate the true-streaming path against NeMo preprocessing.

    Runs two passes over the same audio — (1) NeMo preprocessing + CoreML
    inference (reference) and (2) fully CoreML streaming — then compares the
    per-frame speaker probabilities and prints mean/max absolute error.

    Args:
        model_name: HuggingFace/NGC model id for SortformerEncLabelModel.
        coreml_dir: directory containing SortformerPreprocessor.mlpackage
            and Sortformer16.mlpackage.
        audio_path: wav file to diarize.
    """
    print("=" * 70)
    print("VALIDATION: True Streaming vs NeMo Preprocessing")
    print("=" * 70)

    # Load NeMo model on CPU so results are comparable to CPU-only CoreML runs.
    print(f"\nLoading NeMo Model: {model_name}")
    nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()

    # Patch the streaming geometry from CONFIG onto the model so both
    # inference loops and the exported CoreML model agree.
    modules = nemo_model.sortformer_modules
    modules.chunk_len = CONFIG['chunk_len']
    modules.chunk_right_context = CONFIG['chunk_right_context']
    modules.chunk_left_context = CONFIG['chunk_left_context']
    modules.fifo_len = CONFIG['fifo_len']
    modules.spkcache_len = CONFIG['spkcache_len']
    modules.spkcache_update_period = CONFIG['spkcache_update_period']

    # Disable dither and pad_to so NeMo preprocessing is deterministic and
    # frame counts match the CoreML preprocessor.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, "
          f"right_ctx={modules.chunk_right_context}")

    # Load CoreML models; CPU_ONLY keeps numerics reproducible across devices.
    print(f"Loading CoreML Models from {coreml_dir}...")
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPreprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Sortformer16.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )

    # --- Reference pass ---
    print("\n" + "=" * 70)
    print("TEST 1: NeMo Preprocessing + CoreML Inference (Reference)")
    print("=" * 70)

    ref_probs = run_reference(nemo_model, main_model, audio_path, CONFIG)
    if ref_probs is not None:
        ref_probs_np = ref_probs.squeeze(0).detach().cpu().numpy()
        print(f"Reference Probs Shape: {ref_probs_np.shape}")
    else:
        # Without a reference there is nothing to compare against.
        print("Reference inference failed!")
        return

    # --- True streaming pass ---
    print("\n" + "=" * 70)
    print("TEST 2: True Streaming (Audio → CoreML Preproc → CoreML Main)")
    print("=" * 70)

    streaming_probs = run_true_streaming(nemo_model, preproc_model, main_model, audio_path, CONFIG)

    if streaming_probs is not None:
        streaming_probs_np = streaming_probs.squeeze(0).detach().cpu().numpy()
        print(f"Streaming Probs Shape: {streaming_probs_np.shape}")

        # Compare over the common prefix — the two paths may emit slightly
        # different frame counts due to preprocessor framing differences.
        min_len = min(ref_probs_np.shape[0], streaming_probs_np.shape[0])
        diff = np.abs(ref_probs_np[:min_len] - streaming_probs_np[:min_len])
        print(f"\nLength: ref={ref_probs_np.shape[0]}, streaming={streaming_probs_np.shape[0]}")
        print(f"Mean Absolute Error: {np.mean(diff):.8f}")
        print(f"Max Absolute Error: {np.max(diff):.8f}")

        # 0.01 max-abs-error tolerance on probabilities in [0, 1].
        if np.max(diff) < 0.01:
            print("\n✅ SUCCESS: True streaming matches reference!")
        else:
            print("\n⚠️ Errors exceed tolerance")
    else:
        print("True streaming inference produced no output!")
403
+
404
if __name__ == "__main__":
    # CLI entry point: defaults mirror the exported-model setup.
    arg_parser = argparse.ArgumentParser()
    for flag, default_value in (
        ("--model_name", "nvidia/diar_streaming_sortformer_4spk-v2.1"),
        ("--coreml_dir", "coreml_models"),
        ("--audio_path", "audio.wav"),
    ):
        arg_parser.add_argument(flag, default=default_value)
    cli_args = arg_parser.parse_args()

    validate(cli_args.model_name, cli_args.coreml_dir, cli_args.audio_path)