#!/usr/bin/env python3
"""
Streaming Sortformer CoreML Inference
This script demonstrates how to use the CoreML-converted NVIDIA Streaming Sortformer
model for real-time speaker diarization on Apple Silicon.
Original model: nvidia/diar_streaming_sortformer_4spk-v2.1
"""
import os
import numpy as np
import coremltools as ct
# Streaming configuration mirroring NVIDIA's reference settings for
# diar_streaming_sortformer_4spk-v2.1.
CONFIG = {
    # --- Chunking (units: encoder frames) ---
    "chunk_len": 6,               # Core chunk length
    "chunk_left_context": 1,      # Frames of left context per chunk
    "chunk_right_context": 7,     # Frames of right context per chunk
    # --- Streaming state buffers ---
    "fifo_len": 188,              # FIFO buffer capacity
    "spkcache_len": 188,          # Speaker cache capacity
    "spkcache_update_period": 144,
    # --- Feature geometry ---
    "subsampling_factor": 8,      # Mel frames per encoder frame
    "n_speakers": 4,              # Maximum tracked speakers
    "sample_rate": 16000,
    "mel_features": 128,
}
class SortformerCoreML:
    """CoreML Streaming Sortformer Diarizer.

    Wraps the three exported pipeline stages (preprocessor, pre-encoder, head)
    and the mutable streaming state — a speaker cache and a FIFO of encoder
    embeddings — that carries speaker context across chunks.
    """

    def __init__(self, model_dir: str = ".", compute_units: str = "CPU_ONLY"):
        """
        Initialize the CoreML Sortformer pipeline.

        Args:
            model_dir: Directory containing the .mlpackage files
            compute_units: "CPU_ONLY", "CPU_AND_GPU", or "ALL"
        """
        # Unknown compute-unit names silently fall back to CPU_ONLY.
        cu = getattr(ct.ComputeUnit, compute_units, ct.ComputeUnit.CPU_ONLY)
        # Load models
        self.preprocessor = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Preprocessor.mlpackage"),
            compute_units=cu
        )
        self.pre_encoder = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_PreEncoder.mlpackage"),
            compute_units=cu
        )
        self.head = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Head_Fixed.mlpackage"),
            compute_units=cu
        )
        # Initialize state buffers
        self.reset_state()

    def reset_state(self) -> None:
        """Reset streaming state for new audio session."""
        # Fixed-capacity buffers of 512-dim embeddings; only the first
        # spkcache_len / fifo_len rows are valid at any given time.
        self.spkcache = np.zeros((1, CONFIG["spkcache_len"], 512), dtype=np.float32)
        self.fifo = np.zeros((1, CONFIG["fifo_len"], 512), dtype=np.float32)
        self.spkcache_len = 0  # valid rows in self.spkcache
        self.fifo_len = 0      # valid rows in self.fifo
        self.chunk_idx = 0     # chunks processed since last reset

    def process_chunk(self, mel_features: np.ndarray, chunk_length: int) -> np.ndarray:
        """
        Process a single chunk of mel features.

        Args:
            mel_features: Mel spectrogram chunk [1, T, 128] where T <= 112
            chunk_length: Actual valid length (before padding)

        Returns:
            Speaker predictions [num_frames, 4] with probabilities for each speaker
        """
        # Pad to 112 if needed. 112 = subsampling_factor * (left + core + right
        # context) = 8 * (1 + 6 + 7); chunk_length tells the model how many of
        # those mel frames are real.
        if mel_features.shape[1] < 112:
            pad_len = 112 - mel_features.shape[1]
            mel_features = np.pad(mel_features, ((0, 0), (0, pad_len), (0, 0)))
        # Run PreEncoder: consumes the mel chunk plus the current streaming
        # state (speaker cache + FIFO with their valid lengths).
        pre_out = self.pre_encoder.predict({
            "chunk": mel_features.astype(np.float32),
            "chunk_lengths": np.array([chunk_length], dtype=np.int32),
            "spkcache": self.spkcache,
            "spkcache_lengths": np.array([self.spkcache_len], dtype=np.int32),
            "fifo": self.fifo,
            "fifo_lengths": np.array([self.fifo_len], dtype=np.int32)
        })
        # Run Head
        head_out = self.head.predict({
            "pre_encoder_embs": pre_out["pre_encoder_embs"],
            "pre_encoder_lengths": pre_out["pre_encoder_lengths"],
            "chunk_embs_in": pre_out["chunk_embs_in"],
            "chunk_lens_in": pre_out["chunk_lens_in"]
        })
        # Extract predictions for this chunk. The slicing below assumes
        # speaker_preds is laid out [spkcache | fifo | chunk] along the frame
        # axis — TODO confirm against NeMo's streaming Sortformer reference.
        emb_len = int(head_out["chunk_pre_encoder_lengths"][0])
        lc = 0 if self.chunk_idx == 0 else 1  # Left context (none on the very first chunk)
        rc = CONFIG["chunk_right_context"]
        # Keep only the chunk's core frames, skipping left/right context.
        chunk_pred_len = emb_len - lc - rc
        pred_offset = self.spkcache_len + self.fifo_len + lc
        predictions = head_out["speaker_preds"][0, pred_offset:pred_offset + chunk_pred_len, :]
        # Update state (simplified - full implementation needs NeMo's streaming_update logic)
        self._update_state(pre_out, emb_len)
        self.chunk_idx += 1
        return predictions

    def _update_state(self, pre_out, emb_len: int) -> None:
        """Update spkcache and fifo state buffers.

        Appends the new chunk embeddings to the FIFO; when the FIFO would
        overflow, the oldest FIFO frames are migrated into the speaker cache.

        Args:
            pre_out: PreEncoder output dict; only "chunk_embs_in" is read.
            emb_len: Number of valid embedding frames for the new chunk.
        """
        # Get new chunk embeddings
        new_embs = pre_out["chunk_embs_in"][0, :emb_len, :]
        # Add to fifo
        if self.fifo_len + emb_len <= CONFIG["fifo_len"]:
            # Fast path: chunk fits — append in place.
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
        else:
            # FIFO overflow - move to spkcache
            overflow = self.fifo_len + emb_len - CONFIG["fifo_len"]
            # Move overflow from fifo to spkcache
            if self.spkcache_len + overflow <= CONFIG["spkcache_len"]:
                self.spkcache[0, self.spkcache_len:self.spkcache_len + overflow, :] = \
                    self.fifo[0, :overflow, :]
                self.spkcache_len += overflow
            # NOTE(review): when the spkcache is already full, the overflow
            # frames are silently dropped here; NeMo's streaming_update
            # compresses the cache instead. Acknowledged simplification.
            # Shift fifo and add new
            self.fifo[0, :self.fifo_len - overflow, :] = self.fifo[0, overflow:self.fifo_len, :]
            self.fifo_len -= overflow
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
def process_audio(audio_path: str, model_dir: str = ".") -> list:
    """
    Process an audio file and return diarization results.

    Args:
        audio_path: Path to audio file (16kHz mono WAV)
        model_dir: Directory containing CoreML models

    Returns:
        List of (start_time, end_time, speaker_id) tuples
    """
    import torchaudio

    # Load audio and normalize it to the model's expected format:
    # CONFIG["sample_rate"] Hz (16 kHz), mono.
    waveform, sr = torchaudio.load(audio_path)
    if sr != CONFIG["sample_rate"]:
        waveform = torchaudio.functional.resample(waveform, sr, CONFIG["sample_rate"])
        # Bug fix: keep sr in sync after resampling so the summary below
        # reports the actual rate, not the file's original one.
        sr = CONFIG["sample_rate"]
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Initialize model
    model = SortformerCoreML(model_dir)

    # Compute mel spectrogram using NeMo-compatible settings
    # (You may need to use the Pipeline_Preprocessor or native mel computation)
    # Process in chunks and collect predictions
    # ... (implementation depends on your mel spectrogram computation)
    print(f"Loaded audio: {waveform.shape}, {sr}Hz")
    print("Processing... (implement chunking logic)")
    return []
if __name__ == "__main__":
    import sys

    # Without an audio-file argument, show usage plus the model packages this
    # script expects to find, then exit with a failure status.
    if len(sys.argv) < 2:
        print("Usage: python inference.py <audio_file.wav>")
        print("\nThis script requires:")
        print(" - Pipeline_Preprocessor.mlpackage")
        print(" - Pipeline_PreEncoder.mlpackage")
        print(" - Pipeline_Head_Fixed.mlpackage")
        sys.exit(1)

    # Run diarization and print one line per detected speech segment.
    for start, end, speaker in process_audio(sys.argv[1]):
        print(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker}")