#!/usr/bin/env python3
"""
Streaming Sortformer CoreML Inference
This script demonstrates how to use the CoreML-converted NVIDIA Streaming Sortformer
model for real-time speaker diarization on Apple Silicon.
Original model: nvidia/diar_streaming_sortformer_4spk-v2.1
"""
import os
import numpy as np
import coremltools as ct
# Configuration matching NVIDIA's streaming settings
CONFIG = {
"chunk_len": 6, # Core chunk length in encoder frames
"chunk_left_context": 1, # Left context frames
"chunk_right_context": 7, # Right context frames
"fifo_len": 188, # FIFO buffer length
"spkcache_len": 188, # Speaker cache length
"spkcache_update_period": 144,
"subsampling_factor": 8, # Mel frames per encoder frame
"n_speakers": 4, # Max speakers
"sample_rate": 16000,
"mel_features": 128,
}
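# Derived sizing (a sanity check, not part of the original config): each chunk
# fed to the PreEncoder spans left context + core + right context encoder
# frames, i.e. (1 + 6 + 7) * 8 = 112 mel frames. This is the padding target
# that process_chunk uses below.
MEL_FRAMES_PER_CHUNK = (
    CONFIG["chunk_left_context"] + CONFIG["chunk_len"] + CONFIG["chunk_right_context"]
) * CONFIG["subsampling_factor"]  # = 112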
class SortformerCoreML:
"""CoreML Streaming Sortformer Diarizer"""
def __init__(self, model_dir: str = ".", compute_units: str = "CPU_ONLY"):
"""
Initialize the CoreML Sortformer pipeline.
Args:
model_dir: Directory containing the .mlpackage files
compute_units: "CPU_ONLY", "CPU_AND_GPU", or "ALL"
"""
cu = getattr(ct.ComputeUnit, compute_units, ct.ComputeUnit.CPU_ONLY)
# Load models
self.preprocessor = ct.models.MLModel(
os.path.join(model_dir, "Pipeline_Preprocessor.mlpackage"),
compute_units=cu
)
self.pre_encoder = ct.models.MLModel(
os.path.join(model_dir, "Pipeline_PreEncoder.mlpackage"),
compute_units=cu
)
self.head = ct.models.MLModel(
os.path.join(model_dir, "Pipeline_Head_Fixed.mlpackage"),
compute_units=cu
)
# Initialize state buffers
self.reset_state()
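    # Hedged convenience helper (not in the original): dump each model's
    # input/output tensor names via the public coremltools spec API, useful
    # for verifying the feature names used in process_chunk.
    def describe_models(self):
        """Print input/output feature names for each loaded CoreML model."""
        for name, m in (("preprocessor", self.preprocessor),
                        ("pre_encoder", self.pre_encoder),
                        ("head", self.head)):
            spec = m.get_spec()
            ins = [i.name for i in spec.description.input]
            outs = [o.name for o in spec.description.output]
            print(f"{name}: inputs={ins} outputs={outs}")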
    def reset_state(self):
        """Reset streaming state for a new audio session."""
        # Speaker cache: long-term speaker embeddings (512-dim) kept across chunks
        self.spkcache = np.zeros((1, CONFIG["spkcache_len"], 512), dtype=np.float32)
        # FIFO: recent encoder embeddings awaiting promotion to the speaker cache
        self.fifo = np.zeros((1, CONFIG["fifo_len"], 512), dtype=np.float32)
        self.spkcache_len = 0  # Valid frames currently in spkcache
        self.fifo_len = 0      # Valid frames currently in fifo
        self.chunk_idx = 0     # Number of chunks processed so far
def process_chunk(self, mel_features: np.ndarray, chunk_length: int) -> np.ndarray:
"""
Process a single chunk of mel features.
Args:
mel_features: Mel spectrogram chunk [1, T, 128] where T <= 112
chunk_length: Actual valid length (before padding)
Returns:
Speaker predictions [num_frames, 4] with probabilities for each speaker
"""
        # Pad to the full 112-mel-frame window if needed:
        # 112 = (chunk_left_context + chunk_len + chunk_right_context) * subsampling_factor
        if mel_features.shape[1] < 112:
            pad_len = 112 - mel_features.shape[1]
            mel_features = np.pad(mel_features, ((0, 0), (0, pad_len), (0, 0)))
# Run PreEncoder
pre_out = self.pre_encoder.predict({
"chunk": mel_features.astype(np.float32),
"chunk_lengths": np.array([chunk_length], dtype=np.int32),
"spkcache": self.spkcache,
"spkcache_lengths": np.array([self.spkcache_len], dtype=np.int32),
"fifo": self.fifo,
"fifo_lengths": np.array([self.fifo_len], dtype=np.int32)
})
# Run Head
head_out = self.head.predict({
"pre_encoder_embs": pre_out["pre_encoder_embs"],
"pre_encoder_lengths": pre_out["pre_encoder_lengths"],
"chunk_embs_in": pre_out["chunk_embs_in"],
"chunk_lens_in": pre_out["chunk_lens_in"]
})
        # Extract the predictions belonging to this chunk's core frames.
        # The head sees [spkcache | fifo | chunk], so skip past the cached and
        # FIFO frames plus the left context, then drop the right context.
        emb_len = int(head_out["chunk_pre_encoder_lengths"][0])
        lc = 0 if self.chunk_idx == 0 else 1  # No left context on the first chunk (chunk_left_context = 1)
        rc = CONFIG["chunk_right_context"]
        chunk_pred_len = emb_len - lc - rc
        pred_offset = self.spkcache_len + self.fifo_len + lc
        predictions = head_out["speaker_preds"][0, pred_offset:pred_offset + chunk_pred_len, :]
# Update state (simplified - full implementation needs NeMo's streaming_update logic)
self._update_state(pre_out, emb_len)
self.chunk_idx += 1
return predictions
    def _update_state(self, pre_out, emb_len):
        """Update the spkcache and fifo state buffers.

        Simplified bookkeeping: NeMo's streaming_update additionally compresses
        the speaker cache periodically (spkcache_update_period); here, overflow
        that no longer fits in spkcache is silently dropped.
        """
        # New encoder embeddings produced for this chunk
        new_embs = pre_out["chunk_embs_in"][0, :emb_len, :]
        if self.fifo_len + emb_len <= CONFIG["fifo_len"]:
            # Fast path: the new embeddings fit in the FIFO
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
        else:
            # FIFO would overflow: promote the oldest frames to the speaker cache
            overflow = self.fifo_len + emb_len - CONFIG["fifo_len"]
            if self.spkcache_len + overflow <= CONFIG["spkcache_len"]:
                self.spkcache[0, self.spkcache_len:self.spkcache_len + overflow, :] = \
                    self.fifo[0, :overflow, :]
                self.spkcache_len += overflow
            # Shift the FIFO left and append the new embeddings
            self.fifo[0, :self.fifo_len - overflow, :] = self.fifo[0, overflow:self.fifo_len, :]
            self.fifo_len -= overflow
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
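# Minimal smoke-test sketch (not in the original): push one silent chunk
# through the pipeline. Assumes the three .mlpackage files are present in
# model_dir; 112 zero-valued mel frames are used purely as a shape check.
def _smoke_test(model_dir: str = ".") -> None:
    diarizer = SortformerCoreML(model_dir)
    silent = np.zeros((1, MEL_FRAMES_PER_CHUNK, CONFIG["mel_features"]), dtype=np.float32)
    preds = diarizer.process_chunk(silent, chunk_length=MEL_FRAMES_PER_CHUNK)
    # Expect roughly [chunk_len, n_speakers] frames after context stripping
    print(f"speaker_preds shape: {preds.shape}")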
def process_audio(audio_path: str, model_dir: str = ".") -> list:
"""
Process an audio file and return diarization results.
Args:
audio_path: Path to audio file (16kHz mono WAV)
model_dir: Directory containing CoreML models
Returns:
List of (start_time, end_time, speaker_id) tuples
"""
    # Local import: torchaudio is an optional dependency used only here
    import torchaudio
    # Load audio and normalize to 16 kHz mono
    waveform, sr = torchaudio.load(audio_path)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
# Initialize model
model = SortformerCoreML(model_dir)
    # Compute a mel spectrogram with NeMo-compatible settings (via the
    # Pipeline_Preprocessor model or a native implementation), then run the
    # chunking loop; see the hedged _chunk_loop_sketch below for one approach.
    print(f"Loaded audio: {waveform.shape} at {sr} Hz")
    print("Processing... (implement mel computation + chunking logic)")
    return []
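# Hedged sketch of the missing chunking loop (not in the original): slide a
# 112-mel-frame window over the spectrogram, advancing by the 48-frame core
# step (chunk_len * subsampling_factor), and let process_chunk pad the tail.
# The exact context bookkeeping at stream boundaries should follow NeMo's
# streaming loop; this is illustrative, not a drop-in implementation.
def _chunk_loop_sketch(model: SortformerCoreML, mel: np.ndarray) -> np.ndarray:
    """Run mel features [1, T, 128] through the pipeline chunk by chunk."""
    step = CONFIG["chunk_len"] * CONFIG["subsampling_factor"]             # 48 new mel frames per chunk
    lc_mel = CONFIG["chunk_left_context"] * CONFIG["subsampling_factor"]  # 8-frame left context
    window = MEL_FRAMES_PER_CHUNK                                         # 112 mel frames incl. contexts
    preds = []
    pos = 0
    while pos < mel.shape[1]:
        start = max(0, pos - lc_mel)  # First chunk has no left context
        chunk = mel[:, start:start + window, :]
        preds.append(model.process_chunk(chunk, chunk_length=chunk.shape[1]))
        pos += step
    if not preds:
        return np.zeros((0, CONFIG["n_speakers"]), dtype=np.float32)
    return np.concatenate(preds, axis=0)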
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("Usage: python inference.py <audio_file.wav>")
print("\nThis script requires:")
print(" - Pipeline_Preprocessor.mlpackage")
print(" - Pipeline_PreEncoder.mlpackage")
print(" - Pipeline_Head_Fixed.mlpackage")
sys.exit(1)
results = process_audio(sys.argv[1])
for start, end, speaker in results:
print(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker}")