#!/usr/bin/env python3
"""
Streaming Sortformer CoreML Inference

This script demonstrates how to use the CoreML-converted NVIDIA Streaming
Sortformer model for real-time speaker diarization on Apple Silicon.

Original model: nvidia/diar_streaming_sortformer_4spk-v2.1
"""

import os

import numpy as np

try:
    import coremltools as ct
except ImportError as _exc:
    # Defer the failure until a model is actually constructed so that the
    # usage message and the pure-numpy state helpers remain usable on
    # machines without coremltools installed.
    ct = None
    _CT_IMPORT_ERROR = _exc
else:
    _CT_IMPORT_ERROR = None

# Configuration matching NVIDIA's streaming settings
CONFIG = {
    "chunk_len": 6,  # Core chunk length in encoder frames
    "chunk_left_context": 1,  # Left context frames
    "chunk_right_context": 7,  # Right context frames
    "fifo_len": 188,  # FIFO buffer length
    "spkcache_len": 188,  # Speaker cache length
    "spkcache_update_period": 144,
    "subsampling_factor": 8,  # Mel frames per encoder frame
    "n_speakers": 4,  # Max speakers
    "sample_rate": 16000,
    "mel_features": 128,
}

# Last-axis size of the spkcache / fifo state buffers.
_EMB_DIM = 512

# Mel frames in one padded input chunk: (core + left + right context)
# encoder frames times the subsampling factor, i.e. (6 + 1 + 7) * 8 = 112.
_CHUNK_MEL_FRAMES = (
    CONFIG["chunk_len"]
    + CONFIG["chunk_left_context"]
    + CONFIG["chunk_right_context"]
) * CONFIG["subsampling_factor"]


class SortformerCoreML:
    """CoreML Streaming Sortformer Diarizer"""

    def __init__(self, model_dir: str = ".", compute_units: str = "CPU_ONLY"):
        """
        Initialize the CoreML Sortformer pipeline.

        Args:
            model_dir: Directory containing the .mlpackage files
            compute_units: "CPU_ONLY", "CPU_AND_GPU", or "ALL". Any name that
                is not an attribute of ``ct.ComputeUnit`` falls back to
                CPU_ONLY.

        Raises:
            ImportError: If coremltools could not be imported.
        """
        if ct is None:
            raise ImportError(
                "coremltools is required to load the Sortformer CoreML models"
            ) from _CT_IMPORT_ERROR

        cu = getattr(ct.ComputeUnit, compute_units, ct.ComputeUnit.CPU_ONLY)

        # Load models
        self.preprocessor = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Preprocessor.mlpackage"),
            compute_units=cu,
        )
        self.pre_encoder = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_PreEncoder.mlpackage"),
            compute_units=cu,
        )
        self.head = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Head_Fixed.mlpackage"),
            compute_units=cu,
        )

        # Initialize state buffers
        self.reset_state()

    def reset_state(self):
        """Reset streaming state for new audio session."""
        self.spkcache = np.zeros(
            (1, CONFIG["spkcache_len"], _EMB_DIM), dtype=np.float32
        )
        self.fifo = np.zeros((1, CONFIG["fifo_len"], _EMB_DIM), dtype=np.float32)
        # Number of valid rows currently held in each buffer.
        self.spkcache_len = 0
        self.fifo_len = 0
        # Count of chunks processed so far; chunk 0 has no left context.
        self.chunk_idx = 0

    def _prediction_span(self, emb_len: int) -> tuple:
        """
        Locate the current chunk's rows inside the head's prediction output.

        Args:
            emb_len: Number of valid encoder frames reported for the chunk

        Returns:
            (start, stop) row indices. The span skips the speaker-cache rows,
            the fifo rows and the left-context frames, and excludes the
            right-context frames. ``stop == start`` when the chunk is too
            short to contain any core frames.
        """
        lc = 0 if self.chunk_idx == 0 else CONFIG["chunk_left_context"]
        rc = CONFIG["chunk_right_context"]
        start = self.spkcache_len + self.fifo_len + lc
        # Clamp: a negative length would turn the slice end into a
        # wrap-around index and return unrelated rows.
        length = max(0, emb_len - lc - rc)
        return start, start + length

    def process_chunk(self, mel_features: np.ndarray, chunk_length: int) -> np.ndarray:
        """
        Process a single chunk of mel features.

        Args:
            mel_features: Mel spectrogram chunk [1, T, 128] where T <= 112
            chunk_length: Actual valid length (before padding)

        Returns:
            Speaker predictions [num_frames, 4] with probabilities for each
            speaker. ``num_frames`` is 0 when the chunk holds no frames beyond
            its left/right context.
        """
        # Zero-pad along the time axis up to the fixed chunk size
        missing = _CHUNK_MEL_FRAMES - mel_features.shape[1]
        if missing > 0:
            mel_features = np.pad(mel_features, ((0, 0), (0, missing), (0, 0)))

        # Run PreEncoder
        pre_out = self.pre_encoder.predict({
            "chunk": mel_features.astype(np.float32),
            "chunk_lengths": np.array([chunk_length], dtype=np.int32),
            "spkcache": self.spkcache,
            "spkcache_lengths": np.array([self.spkcache_len], dtype=np.int32),
            "fifo": self.fifo,
            "fifo_lengths": np.array([self.fifo_len], dtype=np.int32),
        })

        # Run Head
        head_out = self.head.predict({
            "pre_encoder_embs": pre_out["pre_encoder_embs"],
            "pre_encoder_lengths": pre_out["pre_encoder_lengths"],
            "chunk_embs_in": pre_out["chunk_embs_in"],
            "chunk_lens_in": pre_out["chunk_lens_in"],
        })

        # Extract predictions for this chunk (must use the state lengths
        # from *before* the update below)
        emb_len = int(head_out["chunk_pre_encoder_lengths"][0])
        start, stop = self._prediction_span(emb_len)
        predictions = head_out["speaker_preds"][0, start:stop, :]

        # Update state (simplified - full implementation needs NeMo's
        # streaming_update logic)
        self._update_state(pre_out, emb_len)
        self.chunk_idx += 1

        return predictions

    def _update_state(self, pre_out, emb_len):
        """
        Update spkcache and fifo state buffers.

        Appends the first ``emb_len`` rows of ``pre_out["chunk_embs_in"]`` to
        the fifo. If they do not fit, the oldest fifo rows are evicted to make
        room; evicted rows are copied into the speaker cache only while the
        cache still has space for all of them, otherwise they are discarded.

        Args:
            pre_out: PreEncoder output dict; only "chunk_embs_in" is read
            emb_len: Number of valid embedding rows to append
        """
        # Get new chunk embeddings
        new_embs = pre_out["chunk_embs_in"][0, :emb_len, :]

        if self.fifo_len + emb_len <= CONFIG["fifo_len"]:
            # Enough room: append to fifo
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
            return

        # FIFO overflow - number of oldest rows that must leave the fifo
        overflow = self.fifo_len + emb_len - CONFIG["fifo_len"]

        # Move overflow from fifo to spkcache (skipped once the cache is full)
        if self.spkcache_len + overflow <= CONFIG["spkcache_len"]:
            self.spkcache[0, self.spkcache_len:self.spkcache_len + overflow, :] = \
                self.fifo[0, :overflow, :]
            self.spkcache_len += overflow

        # Shift fifo and add new
        self.fifo[0, :self.fifo_len - overflow, :] = \
            self.fifo[0, overflow:self.fifo_len, :]
        self.fifo_len -= overflow
        self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
        self.fifo_len += emb_len


def process_audio(audio_path: str, model_dir: str = ".") -> list:
    """
    Process an audio file and return diarization results.

    The chunking / mel-spectrogram stage is not implemented yet, so this
    currently loads the audio and the models, prints a status message and
    returns an empty list.

    Args:
        audio_path: Path to audio file (16kHz mono WAV)
        model_dir: Directory containing CoreML models

    Returns:
        List of (start_time, end_time, speaker_id) tuples
    """
    import torchaudio

    target_sr = CONFIG["sample_rate"]

    # Load audio
    waveform, sr = torchaudio.load(audio_path)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
        # Keep sr in sync with the waveform so the report below is accurate
        sr = target_sr
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono
        waveform = waveform.mean(dim=0, keepdim=True)

    # Initialize model (unused until the chunking logic below is implemented)
    model = SortformerCoreML(model_dir)

    # Compute mel spectrogram using NeMo-compatible settings
    # (You may need to use the Pipeline_Preprocessor or native mel computation)

    # Process in chunks and collect predictions
    # ... (implementation depends on your mel spectrogram computation)

    print(f"Loaded audio: {waveform.shape}, {sr}Hz")
    print("Processing... (implement chunking logic)")

    return []


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python inference.py <audio_path>")
        print("\nThis script requires:")
        print(" - Pipeline_Preprocessor.mlpackage")
        print(" - Pipeline_PreEncoder.mlpackage")
        print(" - Pipeline_Head_Fixed.mlpackage")
        sys.exit(1)

    results = process_audio(sys.argv[1])
    for start, end, speaker in results:
        print(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker}")