|
|
|
|
|
""" |
|
|
Streaming Sortformer CoreML Inference |
|
|
|
|
|
This script demonstrates how to use the CoreML-converted NVIDIA Streaming Sortformer |
|
|
model for real-time speaker diarization on Apple Silicon. |
|
|
|
|
|
Original model: nvidia/diar_streaming_sortformer_4spk-v2.1 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import coremltools as ct |
|
|
|
|
|
|
|
|
# Streaming configuration for the 4-speaker Sortformer pipeline.
# Frame counts below are in encoder (post-subsampling) frames unless noted;
# values mirror the original NeMo streaming config — TODO confirm against
# the exported .mlpackage metadata.
CONFIG = {
    "chunk_len": 6,                  # new encoder frames consumed per streaming step
    "chunk_left_context": 1,         # left-context frames prepended to each chunk (first chunk has none)
    "chunk_right_context": 7,        # right-context (lookahead) frames appended to each chunk
    "fifo_len": 188,                 # capacity of the FIFO embedding buffer (see reset_state)
    "spkcache_len": 188,             # capacity of the speaker-cache embedding buffer
    "spkcache_update_period": 144,   # NOTE(review): not used anywhere in this script — verify
    "subsampling_factor": 8,         # mel frames per encoder frame
    "n_speakers": 4,                 # model emits probabilities for up to 4 speakers
    "sample_rate": 16000,            # expected audio sample rate in Hz
    "mel_features": 128,             # mel filterbank size of the preprocessor output
}
|
|
|
|
|
|
|
|
class SortformerCoreML:
    """CoreML Streaming Sortformer Diarizer.

    Wraps the three CoreML sub-models (preprocessor, pre-encoder, head) and
    maintains the streaming state across chunks:

    - ``spkcache``: long-term speaker-cache embeddings, shape (1, 188, 512)
    - ``fifo``: short-term FIFO embeddings, shape (1, 188, 512)
    - ``spkcache_len`` / ``fifo_len``: number of valid frames in each buffer
    - ``chunk_idx``: index of the next chunk to be processed

    Call ``reset_state()`` between independent audio sessions.
    """

    def __init__(self, model_dir: str = ".", compute_units: str = "CPU_ONLY"):
        """
        Initialize the CoreML Sortformer pipeline.

        Args:
            model_dir: Directory containing the .mlpackage files
            compute_units: "CPU_ONLY", "CPU_AND_GPU", or "ALL"
        """
        # Silently fall back to CPU_ONLY if compute_units is not a valid
        # ct.ComputeUnit attribute name.
        cu = getattr(ct.ComputeUnit, compute_units, ct.ComputeUnit.CPU_ONLY)

        # NOTE(review): self.preprocessor is loaded but never called in this
        # file — presumably callers use it to turn raw audio into mel
        # features before process_chunk(); confirm.
        self.preprocessor = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Preprocessor.mlpackage"),
            compute_units=cu
        )
        # Encoder stage: consumes mel features plus current state buffers.
        self.pre_encoder = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_PreEncoder.mlpackage"),
            compute_units=cu
        )
        # Diarization head: produces per-frame speaker probabilities.
        self.head = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Head_Fixed.mlpackage"),
            compute_units=cu
        )

        # Start from a clean streaming state.
        self.reset_state()

    def reset_state(self):
        """Reset streaming state for new audio session."""
        # Zero-filled buffers; only the first spkcache_len / fifo_len frames
        # are considered valid. 512 is the encoder embedding dimension.
        self.spkcache = np.zeros((1, CONFIG["spkcache_len"], 512), dtype=np.float32)
        self.fifo = np.zeros((1, CONFIG["fifo_len"], 512), dtype=np.float32)
        self.spkcache_len = 0
        self.fifo_len = 0
        self.chunk_idx = 0

    def process_chunk(self, mel_features: np.ndarray, chunk_length: int) -> np.ndarray:
        """
        Process a single chunk of mel features.

        Args:
            mel_features: Mel spectrogram chunk [1, T, 128] where T <= 112
            chunk_length: Actual valid length (before padding)

        Returns:
            Speaker predictions [num_frames, 4] with probabilities for each speaker
        """
        # The models expect a fixed chunk length of 112 mel frames; zero-pad
        # shorter (e.g. final) chunks on the time axis. 112 appears to be
        # (left + chunk + right contexts) * subsampling = (1+6+7)*8 — verify.
        if mel_features.shape[1] < 112:
            pad_len = 112 - mel_features.shape[1]
            mel_features = np.pad(mel_features, ((0, 0), (0, pad_len), (0, 0)))

        # Run the encoder with the current streaming state. Length inputs
        # tell the model how many frames of each buffer are valid.
        pre_out = self.pre_encoder.predict({
            "chunk": mel_features.astype(np.float32),
            "chunk_lengths": np.array([chunk_length], dtype=np.int32),
            "spkcache": self.spkcache,
            "spkcache_lengths": np.array([self.spkcache_len], dtype=np.int32),
            "fifo": self.fifo,
            "fifo_lengths": np.array([self.fifo_len], dtype=np.int32)
        })

        # Feed encoder embeddings to the diarization head.
        head_out = self.head.predict({
            "pre_encoder_embs": pre_out["pre_encoder_embs"],
            "pre_encoder_lengths": pre_out["pre_encoder_lengths"],
            "chunk_embs_in": pre_out["chunk_embs_in"],
            "chunk_lens_in": pre_out["chunk_lens_in"]
        })

        # Number of valid encoder frames produced for this chunk.
        emb_len = int(head_out["chunk_pre_encoder_lengths"][0])
        # The first chunk has no left context; later chunks carry one
        # left-context frame that must be skipped in the output.
        lc = 0 if self.chunk_idx == 0 else 1
        rc = CONFIG["chunk_right_context"]
        # Frames that are genuinely "new" for this chunk (context removed).
        chunk_pred_len = emb_len - lc - rc

        # speaker_preds appears to be laid out as [spkcache | fifo | chunk]
        # along time; skip the state frames and the left context to slice
        # out this chunk's predictions — TODO confirm against model spec.
        pred_offset = self.spkcache_len + self.fifo_len + lc
        predictions = head_out["speaker_preds"][0, pred_offset:pred_offset + chunk_pred_len, :]

        # Push this chunk's embeddings into the streaming buffers.
        self._update_state(pre_out, emb_len)

        self.chunk_idx += 1
        return predictions

    def _update_state(self, pre_out, emb_len):
        """Update spkcache and fifo state buffers."""
        # NOTE(review): emb_len includes the right-context frames; confirm
        # the FIFO should receive them (risk of duplicating lookahead frames
        # across consecutive chunks).
        new_embs = pre_out["chunk_embs_in"][0, :emb_len, :]

        if self.fifo_len + emb_len <= CONFIG["fifo_len"]:
            # Fast path: the new embeddings fit — append to the FIFO tail.
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
        else:
            # FIFO would overflow: spill the oldest `overflow` frames into
            # the speaker cache, then shift the FIFO left and append.
            overflow = self.fifo_len + emb_len - CONFIG["fifo_len"]

            # NOTE(review): when the speaker cache is also full, the
            # overflow frames are silently discarded (no compression step
            # as in the original NeMo implementation) — confirm intended.
            if self.spkcache_len + overflow <= CONFIG["spkcache_len"]:
                self.spkcache[0, self.spkcache_len:self.spkcache_len + overflow, :] = \
                    self.fifo[0, :overflow, :]
                self.spkcache_len += overflow

            # Shift surviving FIFO frames to the front. The slices overlap,
            # but the copy moves data toward lower indices, which is safe
            # for a forward elementwise copy.
            self.fifo[0, :self.fifo_len - overflow, :] = self.fifo[0, overflow:self.fifo_len, :]
            self.fifo_len -= overflow
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
|
|
|
|
|
|
|
|
def process_audio(audio_path: str, model_dir: str = ".") -> list:
    """
    Process an audio file and return diarization results.

    Args:
        audio_path: Path to audio file (any format torchaudio can read;
            resampled to 16 kHz and downmixed to mono as needed)
        model_dir: Directory containing CoreML models

    Returns:
        List of (start_time, end_time, speaker_id) tuples.
        NOTE: the chunking/inference loop is not implemented yet, so this
        currently always returns an empty list.
    """
    # Local import so the rest of the module works without torchaudio.
    import torchaudio

    # Normalize the audio to the model's expected format: 16 kHz, mono.
    waveform, sr = torchaudio.load(audio_path)
    if sr != CONFIG["sample_rate"]:
        waveform = torchaudio.functional.resample(waveform, sr, CONFIG["sample_rate"])
        # Fix: keep `sr` in sync with the resampled data, otherwise the
        # log line below reports the original (stale) sample rate.
        sr = CONFIG["sample_rate"]
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Instantiate the streaming pipeline (loads the three .mlpackage files).
    model = SortformerCoreML(model_dir)

    print(f"Loaded audio: {waveform.shape}, {sr}Hz")
    print("Processing... (implement chunking logic)")

    return []
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Require exactly one argument: the audio file to diarize.
    if len(sys.argv) >= 2:
        segments = process_audio(sys.argv[1])
        for begin, finish, spk in segments:
            print(f"[{begin:.2f}s - {finish:.2f}s] Speaker {spk}")
    else:
        print("Usage: python inference.py <audio_file.wav>")
        print("\nThis script requires:")
        print(" - Pipeline_Preprocessor.mlpackage")
        print(" - Pipeline_PreEncoder.mlpackage")
        print(" - Pipeline_Head_Fixed.mlpackage")
        sys.exit(1)
|
|
|