File size: 6,843 Bytes
f9a579a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python3
"""
Streaming Sortformer CoreML Inference

This script demonstrates how to use the CoreML-converted NVIDIA Streaming Sortformer
model for real-time speaker diarization on Apple Silicon.

Original model: nvidia/diar_streaming_sortformer_4spk-v2.1
"""

import os
import numpy as np
import coremltools as ct

# Configuration matching NVIDIA's streaming settings.
# Lengths are in encoder frames unless noted otherwise; one encoder frame
# covers `subsampling_factor` mel frames.
CONFIG = {
    "chunk_len": 6,              # Core chunk length in encoder frames
    "chunk_left_context": 1,     # Left context frames
    "chunk_right_context": 7,    # Right context frames
    "fifo_len": 188,             # FIFO buffer length
    "spkcache_len": 188,         # Speaker cache length
    "spkcache_update_period": 144,  # presumably frames between speaker-cache updates — TODO confirm against NeMo config
    "subsampling_factor": 8,     # Mel frames per encoder frame
    "n_speakers": 4,             # Max speakers
    "sample_rate": 16000,        # Expected input audio sample rate (Hz)
    "mel_features": 128,         # Number of mel filterbank features
}


class SortformerCoreML:
    """CoreML Streaming Sortformer Diarizer.

    Wraps the three exported pipeline stages (preprocessor, pre-encoder, head)
    and maintains the streaming state — a speaker-cache buffer and a FIFO
    buffer of embeddings — that carries speaker information across chunks.
    Call :meth:`process_chunk` once per mel chunk; call :meth:`reset_state`
    before starting a new audio session.
    """

    def __init__(self, model_dir: str = ".", compute_units: str = "CPU_ONLY"):
        """
        Initialize the CoreML Sortformer pipeline.

        Args:
            model_dir: Directory containing the .mlpackage files
            compute_units: "CPU_ONLY", "CPU_AND_GPU", or "ALL"
        """
        # Unrecognized compute-unit names silently fall back to CPU_ONLY.
        cu = getattr(ct.ComputeUnit, compute_units, ct.ComputeUnit.CPU_ONLY)

        # Load models
        # NOTE(review): self.preprocessor is loaded but never used by this
        # class — callers are expected to compute mel features themselves
        # (see process_audio); confirm whether it should be wired in.
        self.preprocessor = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Preprocessor.mlpackage"),
            compute_units=cu
        )
        self.pre_encoder = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_PreEncoder.mlpackage"),
            compute_units=cu
        )
        self.head = ct.models.MLModel(
            os.path.join(model_dir, "Pipeline_Head_Fixed.mlpackage"),
            compute_units=cu
        )

        # Initialize state buffers
        self.reset_state()

    def reset_state(self) -> None:
        """Reset streaming state for new audio session."""
        # Fixed-size embedding buffers; 512 is presumably the encoder
        # embedding dimension — TODO confirm against the exported model spec.
        self.spkcache = np.zeros((1, CONFIG["spkcache_len"], 512), dtype=np.float32)
        self.fifo = np.zeros((1, CONFIG["fifo_len"], 512), dtype=np.float32)
        # Counts of valid frames currently held in each buffer above.
        self.spkcache_len = 0
        self.fifo_len = 0
        # Number of chunks processed; used to decide left-context handling.
        self.chunk_idx = 0

    def process_chunk(self, mel_features: np.ndarray, chunk_length: int) -> np.ndarray:
        """
        Process a single chunk of mel features.

        Args:
            mel_features: Mel spectrogram chunk [1, T, 128] where T <= 112
            chunk_length: Actual valid length (before padding)

        Returns:
            Speaker predictions [num_frames, 4] with probabilities for each speaker
        """
        # Pad to 112 if needed
        # NOTE(review): no guard for T > 112 — a longer chunk is passed
        # through unpadded and may be rejected by the fixed-shape model.
        if mel_features.shape[1] < 112:
            pad_len = 112 - mel_features.shape[1]
            mel_features = np.pad(mel_features, ((0, 0), (0, pad_len), (0, 0)))

        # Run PreEncoder
        pre_out = self.pre_encoder.predict({
            "chunk": mel_features.astype(np.float32),
            "chunk_lengths": np.array([chunk_length], dtype=np.int32),
            "spkcache": self.spkcache,
            "spkcache_lengths": np.array([self.spkcache_len], dtype=np.int32),
            "fifo": self.fifo,
            "fifo_lengths": np.array([self.fifo_len], dtype=np.int32)
        })

        # Run Head
        head_out = self.head.predict({
            "pre_encoder_embs": pre_out["pre_encoder_embs"],
            "pre_encoder_lengths": pre_out["pre_encoder_lengths"],
            "chunk_embs_in": pre_out["chunk_embs_in"],
            "chunk_lens_in": pre_out["chunk_lens_in"]
        })

        # Extract predictions for this chunk
        # Valid encoder-frame count for this chunk as reported by the head.
        emb_len = int(head_out["chunk_pre_encoder_lengths"][0])
        lc = 0 if self.chunk_idx == 0 else 1  # Left context
        rc = CONFIG["chunk_right_context"]
        # Frames belonging to the chunk proper, excluding context on both sides.
        chunk_pred_len = emb_len - lc - rc

        # speaker_preds covers [spkcache | fifo | chunk]; skip past the cache,
        # the FIFO and the left context to land on this chunk's frames.
        # Offsets use the pre-update lengths, so state is updated only below.
        pred_offset = self.spkcache_len + self.fifo_len + lc
        predictions = head_out["speaker_preds"][0, pred_offset:pred_offset + chunk_pred_len, :]

        # Update state (simplified - full implementation needs NeMo's streaming_update logic)
        self._update_state(pre_out, emb_len)

        self.chunk_idx += 1
        return predictions

    def _update_state(self, pre_out: dict, emb_len: int) -> None:
        """Update spkcache and fifo state buffers.

        Appends the chunk's `emb_len` new embeddings to the FIFO; on FIFO
        overflow, the oldest frames spill into the speaker cache. Simplified
        relative to NeMo's streaming_update (no cache compression/scoring).
        """
        # Get new chunk embeddings
        new_embs = pre_out["chunk_embs_in"][0, :emb_len, :]

        # Add to fifo
        if self.fifo_len + emb_len <= CONFIG["fifo_len"]:
            # Fast path: FIFO has room, append in place.
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len
        else:
            # FIFO overflow - move to spkcache
            # Number of oldest FIFO frames that must be evicted to fit new_embs.
            overflow = self.fifo_len + emb_len - CONFIG["fifo_len"]

            # Move overflow from fifo to spkcache
            # NOTE(review): if the speaker cache is already full, the evicted
            # frames are silently discarded (no compression as in NeMo).
            if self.spkcache_len + overflow <= CONFIG["spkcache_len"]:
                self.spkcache[0, self.spkcache_len:self.spkcache_len + overflow, :] = \
                    self.fifo[0, :overflow, :]
                self.spkcache_len += overflow

            # Shift fifo and add new
            self.fifo[0, :self.fifo_len - overflow, :] = self.fifo[0, overflow:self.fifo_len, :]
            self.fifo_len -= overflow
            self.fifo[0, self.fifo_len:self.fifo_len + emb_len, :] = new_embs
            self.fifo_len += emb_len


def process_audio(audio_path: str, model_dir: str = ".") -> list:
    """
    Process an audio file and return diarization results.

    Loads the audio, normalizes it to 16 kHz mono, and initializes the CoreML
    pipeline. The chunking / mel-spectrogram loop is still a stub, so the
    result list is currently always empty.

    Args:
        audio_path: Path to audio file (16kHz mono WAV)
        model_dir: Directory containing CoreML models

    Returns:
        List of (start_time, end_time, speaker_id) tuples
    """
    # Imported lazily so the module can be loaded without torchaudio installed.
    # (The previously separate `import torch` was unused and has been removed.)
    import torchaudio

    # Load audio and normalize to the model's expected 16 kHz mono format.
    waveform, sr = torchaudio.load(audio_path)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        # Keep `sr` consistent with the (now resampled) waveform so the
        # summary line below doesn't report the original file's rate.
        sr = 16000
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Initialize model
    model = SortformerCoreML(model_dir)

    # Compute mel spectrogram using NeMo-compatible settings
    # (You may need to use the Pipeline_Preprocessor or native mel computation)

    # Process in chunks and collect predictions
    # ... (implementation depends on your mel spectrogram computation)

    print(f"Loaded audio: {waveform.shape}, {sr}Hz")
    print("Processing... (implement chunking logic)")

    return []


if __name__ == "__main__":
    import sys

    # An audio file argument is required; otherwise print usage and exit.
    if len(sys.argv) >= 2:
        # Print one line per diarization segment.
        for start, end, speaker in process_audio(sys.argv[1]):
            print(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker}")
    else:
        print("Usage: python inference.py <audio_file.wav>")
        print("\nThis script requires:")
        print("  - Pipeline_Preprocessor.mlpackage")
        print("  - Pipeline_PreEncoder.mlpackage")
        print("  - Pipeline_Head_Fixed.mlpackage")
        sys.exit(1)