EtMmohammedHafsati committed
Commit c2f1451 · verified · 1 Parent(s): d605fb2

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +317 -0
  2. doa_model.onnx +3 -0
  3. features.py +186 -0
  4. onnx_stream_microphone.py +796 -0
  5. silero_vad.onnx +3 -0
README.md ADDED
@@ -0,0 +1,317 @@
+ # ONNX Real-Time DOA Streaming
+
+ Real-time Direction of Arrival (DOA) detection using an ONNX model with microphone streaming. This script processes audio from a multi-channel microphone array (ReSpeaker) in real time and displays the detected sound-source directions.
+
+ ## Overview
+
+ The script performs the following steps:
+
+ 1. **Audio Capture**: Streams audio from a 6-channel microphone array (ReSpeaker)
+ 2. **Channel Selection**: Selects and reorders channels `[1, 4, 3, 2]` to obtain 4 channels
+ 3. **Feature Extraction**: Computes STFT features (magnitude, phase, cosine, sine) from the audio
+ 4. **ONNX Inference**: Runs the DOA model on GPU (CUDA) or CPU to get per-frame logits
+ 5. **Histogram Aggregation**: Aggregates logits into a circular histogram of azimuth angles
+ 6. **Peak Detection**: Finds peaks in the histogram to identify sound-source directions
+ 7. **Event Gating**: Filters detections based on audio-level changes and coherence
+ 8. **Visualization**: Displays detected directions on a polar plot in real time
+
+ ## Prerequisites
+
+ ### Hardware
+
+ - **ReSpeaker 6-Mic Array** (or a compatible multi-channel microphone) with the following microphone positions:
+
+   ```yaml
+   microphone:
+     positions:
+       - [0.0277, 0.0]   # Mic 0: 0°
+       - [0.0, 0.0277]   # Mic 1: 90°
+       - [-0.0277, 0.0]  # Mic 2: 180°
+       - [0.0, -0.0277]  # Mic 3: 270°
+   ```
+
+ - **NVIDIA GPU** (optional, for faster inference)
+
+ ### Software Dependencies
+
+ Install the required packages:
+
+ ```bash
+ conda activate doaEnv
+ pip install onnxruntime-gpu  # For GPU inference
+ # OR
+ pip install onnxruntime      # For CPU-only inference
+
+ pip install pyaudio numpy matplotlib torch pyyaml
+ ```
+
+ ### ONNX Model
+
+ You need a converted ONNX model file. If you haven't converted your PyTorch model yet:
+
+ ```bash
+ python convert_to_onnx.py --checkpoint models/basic/2025-11-06_22-37-00-6a5fbc92/last.pt --output models/basic/2025-11-06_22-37-00-6a5fbc92/last.onnx
+ ```
+
+ ## Quick Start
+
+ ### 1. List Available Audio Devices
+
+ First, find your ReSpeaker device index:
+
+ ```bash
+ python onnx_stream_microphone.py --list-devices
+ ```
+
+ Look for a device named "ReSpeaker" or "Seeed", or one whose name contains "2886". Note the device index.
+
+ ### 2. Stop PulseAudio (Required)
+
+ On Linux, PulseAudio often locks the ALSA devices, so you need to stop it temporarily:
+
+ ```bash
+ pulseaudio --kill
+ ```
+
+ **Note**: You can use the helper script `run_onnx_stream.sh`, which automates this (see below).
+
+ ### 3. Run the Streaming Script
+
+ Basic usage:
+
+ ```bash
+ python onnx_stream_microphone.py \
+     --onnx models/basic/2025-11-06_22-37-00-6a5fbc92/last.onnx \
+     --device-index 9
+ ```
+
+ ### 4. Restart PulseAudio (After Stopping)
+
+ After you're done, restart PulseAudio:
+
+ ```bash
+ pulseaudio --start
+ ```
+
+ ## Using the Helper Script
+
+ A helper script automates PulseAudio management:
+
+ ```bash
+ chmod +x run_onnx_stream.sh
+ ./run_onnx_stream.sh --onnx models/basic/2025-11-06_22-37-00-6a5fbc92/last.onnx --device-index 9
+ ```
+
+ This script will:
+ 1. Stop PulseAudio
+ 2. Run the streaming script
+ 3. Restart PulseAudio when you exit (Ctrl+C)
+
+ ## Command-Line Arguments
+
+ ### Required Arguments
+
+ - `--onnx PATH`: Path to the ONNX model file
+
+ ### Audio Configuration
+
+ - `--device-index INT`: Audio device index (use `--list-devices` to find it)
+ - `--sample-rate INT`: Sample rate in Hz (default: 16000)
+ - `--window-ms INT`: Analysis window length in milliseconds (default: 200)
+ - `--hop-ms INT`: Hop size (overlap) in milliseconds (default: 100)
+ - `--chunk-size INT`: Audio buffer chunk size (default: 1600)
+ - `--cpu-only`: Use CPU only (disable GPU inference)
+ - `--list-devices`: List all available audio input devices and exit
+
+ ### Model Configuration
+
+ - `--config PATH`: Path to config.yaml (default: `configs/train.yaml`)
+
+ ### Histogram Detection Parameters
+
+ These control how DOA peaks are detected from the model logits:
+
+ - `--K INT`: Number of azimuth bins (default: 72; should match the model)
+ - `--tau FLOAT`: Softmax temperature for the histogram (default: 0.8)
+ - `--smooth-k INT`: Histogram smoothing kernel half-width (default: 1)
+ - `--min-peak-height FLOAT`: Minimum peak height threshold (default: 0.10)
+ - `--min-window-mass FLOAT`: Minimum window mass for peak validation (default: 0.24)
+ - `--min-sep-deg FLOAT`: Minimum angular separation between peaks in degrees (default: 20.0)
+ - `--min-active-ratio FLOAT`: Minimum active frame ratio (default: 0.20)
+ - `--max-sources INT`: Maximum number of sources to detect (default: 3)
+
+ ### Event Gate Parameters
+
+ These control when detections are considered valid (filtering noise):
+
+ - `--level-delta-on-db FLOAT`: Level increase threshold to open the gate (default: 2.5)
+ - `--level-delta-off-db FLOAT`: Level decrease threshold to close the gate (default: 1.0)
+ - `--level-min-dbfs FLOAT`: Minimum audio level in dBFS (default: -60.0)
+ - `--level-ema-alpha FLOAT`: Exponential moving average alpha for level tracking (default: 0.05)
+ - `--event-hold-ms INT`: Minimum time to keep the gate open after a detection (default: 300)
+ - `--min-R-clip FLOAT`: Minimum R_clip (coherence measure) to open the gate (default: 0.18)
+ - `--event-refractory-ms INT`: Minimum time between gate state changes (default: 120)
+
+ ### Onset Detection Parameters
+
+ - `--onset-alpha FLOAT`: EMA alpha for spectral flux tracking (default: 0.05)
+
+ ## Example with Custom Parameters
+
+ ```bash
+ python onnx_stream_microphone.py \
+     --onnx doa_model.onnx \
+     --device-index 9 \
+     --window-ms 400 \
+     --hop-ms 100 \
+     --K 72 \
+     --max-sources 2 \
+     --tau 0.8 \
+     --smooth-k 1 \
+     --min-peak-height 0.08 \
+     --min-window-mass 0.16 \
+     --min-sep-deg 22.5 \
+     --min-active-ratio 0.15 \
+     --level-delta-on-db 4.0 \
+     --level-delta-off-db 1.5 \
+     --level-min-dbfs -55.0 \
+     --level-ema-alpha 0.05 \
+     --event-hold-ms 320 \
+     --event-refractory-ms 200 \
+     --min-R-clip 0.30 \
+     --onset-alpha 0.05
+ ```
+
+ ## Understanding the Output
+
+ ### Console Output
+
+ Each line shows:
+
+ ```
+ [ 12.34s] LVL= -45.2 dBFS diff=+3.5 | FLUXz=2.10 COH=0.75 | GATE=OPEN | MODEL= 12.3ms HIST= 2.1ms | DOA(R=0.45, n=2) [45°, 180°]
+ ```
+
+ - `[time]`: Elapsed time in seconds
+ - `LVL`: Audio level in dBFS
+ - `diff`: Level difference from the background (dB)
+ - `FLUXz`: Spectral flux z-score (onset detection)
+ - `COH`: Inter-microphone coherence
+ - `GATE`: Gate state (OPEN/CLOSED)
+ - `MODEL`: Model inference time (ms)
+ - `HIST`: Histogram processing time (ms)
+ - `DOA(R=..., n=...)`: R_clip value and number of detected peaks
+ - `[angles]`: Detected azimuth angles in degrees
+
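The `LVL` field follows the script's `rms_dbfs` helper; a minimal self-contained version is sketched below (the 440 Hz test tone is illustrative, not part of the script):

```python
import numpy as np

def rms_dbfs(x: np.ndarray, eps: float = 1e-9) -> float:
    """RMS level of a float waveform in dBFS (full-scale RMS of 1.0 -> 0 dBFS)."""
    rms = np.sqrt(np.mean(x * x))
    return 20.0 * np.log10(max(rms, eps))

# A 440 Hz sine at amplitude 0.1 has RMS 0.1/sqrt(2), i.e. about -23 dBFS
tone = 0.1 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)
level = rms_dbfs(tone)
```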
+ ### Visual Output
+
+ A polar plot window shows:
+
+ - **Green lines**: Detected sound-source directions
+ - **Line thickness**: Proportional to the confidence score
+ - **Angle labels**: Azimuth in degrees (0° = North/front)
+
+ ### Azimuth Convention
+
+ - **0°** = North (front of the microphone)
+ - **90°** = East (right)
+ - **180°** = South (back)
+ - **270°** = West (left)
+
+ ## How It Works
+
+ ### 1. Audio Processing Pipeline
+
+ ```
+ Microphone (6 ch) → Channel Selection [1,4,3,2] → 4-channel audio
+ ```
+
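The int16-to-float conversion and channel reorder mirror the script's `byte_to_float`/`chunk_to_floatarray` helpers; a minimal sketch (the 4-frame test chunk is made up for illustration):

```python
import numpy as np

def chunk_to_floatarray(data: bytes, channels: int) -> np.ndarray:
    # Interleaved int16 PCM -> (channels, samples) float32 in [-1, 1)
    samples = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
    return samples.reshape(-1, channels).T

# A made-up 6-channel chunk of 4 frames (24 interleaved samples)
raw = np.arange(24, dtype=np.int16).tobytes()
audio = chunk_to_floatarray(raw, channels=6)  # shape (6, 4)
doa_input = audio[[1, 4, 3, 2], :]            # keep 4 mics in model order
```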
+ ### 2. Feature Extraction
+
+ For each analysis window:
+
+ - Compute the STFT for all 4 channels
+ - Extract magnitude, phase, cosine, and sine components
+ - Result: `(T_frames, 12_features, F_freq_bins)`
+
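A pure-numpy sketch of how those shapes arise for 4 channels at 16 kHz with a 32 ms window and 10 ms hop (the script itself uses `stft_multi` and `compute_mag_phase_cos_sin` from `features.py`; the random half-second signal is only for illustration):

```python
import numpy as np

fs = 16000
x = np.random.randn(fs // 2, 4)        # 0.5 s of 4-channel audio, shape (N, C)
frame_len, hop = 512, 160              # 32 ms window, 10 ms hop at 16 kHz

# Frame, window, and FFT each channel: X has shape (T, C, F)
T = 1 + (x.shape[0] - frame_len) // hop
w = np.hanning(frame_len)
X = np.stack([
    np.fft.rfft(x[t * hop:t * hop + frame_len] * w[:, None], axis=0).T
    for t in range(T)
])                                     # (T, 4, 257)

# Per-channel magnitude plus cos/sin of the phase -> (T, 3*C, F) = (T, 12, F)
mag, ph = np.abs(X), np.angle(X)
feats = np.concatenate([mag, np.cos(ph), np.sin(ph)], axis=1).astype(np.float32)
```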
+ ### 3. Model Inference
+
+ - Batch-process features through the ONNX model
+ - Output: `(T_frames, K_bins)` logits per frame
+ - Each frame gets K scores, one per azimuth bin (converted to probabilities in the next step)
+
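The batching can be sketched as below; `fake_run` is a stand-in for the ONNX session call (roughly `lambda b: session.run([output_name], {input_name: b})[0]`), and the random projection exists only so the sketch runs without a model:

```python
import numpy as np

def run_model_batched(feats: np.ndarray, run, batch_size: int = 25) -> np.ndarray:
    # Feed (T, 12, F) features through `run` in chunks; stack the (chunk, K) logits
    outs = [run(feats[i:i + batch_size]) for i in range(0, len(feats), batch_size)]
    return np.concatenate(outs, axis=0)

# Stand-in "model": a fixed random projection of each frame's features to K=72 logits
rng = np.random.default_rng(0)
W = rng.standard_normal((12 * 257, 72)).astype(np.float32)
fake_run = lambda batch: batch.reshape(len(batch), -1) @ W

feats = rng.standard_normal((47, 12, 257)).astype(np.float32)
logits = run_model_batched(feats, fake_run)  # (47, 72)
```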
+ ### 4. Histogram Aggregation
+
+ - Apply softmax with temperature `tau` to the logits
+ - Weight frames by their circular coherence (R_clip)
+ - Aggregate all frames into a single histogram
+ - Smooth the histogram circularly
+
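A numpy sketch of the aggregation step (mirroring the script's `_aggregate_histogram`, minus the circular smoothing and VAD mask; the synthetic logits are illustrative):

```python
import numpy as np

def aggregate_histogram(logits: np.ndarray, tau: float = 0.8, gamma: float = 1.5) -> np.ndarray:
    """Coherence-weighted circular histogram over the K azimuth bins."""
    T, K = logits.shape
    deg = (np.arange(K) + 0.5) * (360.0 / K)           # bin centers
    cos_b, sin_b = np.cos(np.radians(deg)), np.sin(np.radians(deg))

    # Softmax with temperature tau
    z = np.exp((logits - logits.max(axis=-1, keepdims=True)) / tau)
    probs = z / z.sum(axis=-1, keepdims=True)          # (T, K)

    # Per-frame circular resultant length R_t in [0, 1]: confident frames weigh more
    R_t = np.clip(np.hypot(probs @ cos_b, probs @ sin_b), 0.0, 1.0)
    hist = (R_t ** gamma) @ probs
    return hist / max(hist.sum(), 1e-8)

logits = np.zeros((10, 72), dtype=np.float32)
logits[:, 18] = 5.0                                    # every frame votes for bin 18 (~92.5°)
hist = aggregate_histogram(logits)
```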
+ ### 5. Peak Detection
+
+ - Find local maxima in the histogram
+ - Filter by minimum height, separation, and window mass
+ - Refine peak positions using parabolic interpolation
+ - Return up to `max_sources` peaks
+
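The parabolic refinement fits a parabola through a peak bin and its two circular neighbours; a sketch matching the script's `_parabolic_peak_refine_np` (the 5-bin histogram is a toy example):

```python
import numpy as np

def parabolic_refine(row: np.ndarray, k: int) -> float:
    """Sub-bin offset in [-0.5, 0.5] of the true peak around bin k (circular)."""
    K = row.size
    y1, y2, y3 = row[(k - 1) % K], row[k], row[(k + 1) % K]
    denom = y1 - 2 * y2 + y3
    if abs(denom) < 1e-9:              # flat triple: no refinement
        return 0.0
    return float(np.clip(0.5 * (y1 - y3) / denom, -0.5, 0.5))

hist = np.array([0.0, 0.2, 0.5, 0.4, 0.0])
delta = parabolic_refine(hist, 2)                  # > 0: true peak lies toward bin 3
angle = (2 + 0.5 + delta) * (360.0 / 5)            # bin-center convention from the script
```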
+ ### 6. Event Gating
+
+ - Track the audio level with an exponential moving average
+ - Open the gate when:
+   - the level increases by `level_delta_on_db`, OR
+   - valid peaks are detected AND R_clip > `min_R_clip`
+ - Close the gate when the level drops and there are no valid peaks
+ - Apply hold and refractory periods to prevent flickering
+
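The hysteresis can be sketched as a single update step (a simplified version of the script's `LevelChangeGate`, with the hold and refractory timers omitted):

```python
def gate_step(level_dbfs, bg_dbfs, active, peaks, R_clip,
              delta_on=2.5, delta_off=1.0, level_min=-60.0,
              min_R=0.18, alpha=0.05):
    """One gate update: returns (new_active, new_background_level)."""
    diff = level_dbfs - bg_dbfs
    if not active:
        # Open on a level jump above the floor, or on coherent peaks
        active = (level_dbfs > level_min and diff >= delta_on) or \
                 (peaks > 0 and R_clip >= min_R)
    elif diff <= delta_off and (peaks == 0 or R_clip < min_R):
        active = False                 # level back near background, nothing detected
    bg_dbfs = (1 - alpha) * bg_dbfs + alpha * level_dbfs   # slow background EMA
    return active, bg_dbfs

state = gate_step(-44.0, -50.0, False, peaks=1, R_clip=0.3)  # loud + coherent: opens
```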
+ ## Troubleshooting
+
+ ### "Invalid number of channels" Error
+
+ **Problem**: The device reports 0 channels or PyAudio can't open it.
+
+ **Solution**:
+ 1. Stop PulseAudio: `pulseaudio --kill`
+ 2. Run the script
+ 3. Restart PulseAudio: `pulseaudio --start`
+
+ Or use the helper script `run_onnx_stream.sh`.
+
+ ### No Audio Detected
+
+ - Check the microphone connections
+ - Verify the device index with `--list-devices`
+ - Check the audio levels (they should be above `level_min_dbfs`)
+ - Lower `level_delta_on_db` to make detection more sensitive
+
+ ### GPU Not Used
+
+ - Verify CUDA is available: `python -c "import torch; print(torch.cuda.is_available())"`
+ - Install `onnxruntime-gpu` instead of `onnxruntime`
+ - Check that the CUDA provider is listed in the model-loading message
+
+ ### Model Mismatch Errors
+
+ - Ensure `--K` matches the model's K value (usually 72)
+ - Check that the ONNX model was exported with the correct input shape
+ - Verify config.yaml matches the training configuration
+
+ ### Poor DOA Accuracy
+
+ - Increase `--window-ms` for longer, more stable analysis windows
+ - Adjust the `--min-peak-height` and `--min-window-mass` thresholds
+ - Tune `--tau` (lower = sharper peaks, higher = smoother)
+ - Check the microphone array calibration and positioning
+
+ ## Performance Tips
+
+ - **GPU Inference**: Use `onnxruntime-gpu` for a 5-10x speedup
+ - **Window Size**: Larger windows (400 ms) are more stable but add latency
+ - **Hop Size**: Smaller hops (50 ms) are more responsive but cost more computation
+ - **Batch Size**: The script uses batch_size=25 internally for efficient GPU usage
+
+ ## Stopping the Script
+
+ Press **Ctrl+C** to stop the stream. The script will:
+
+ - Close the audio stream
+ - Close the visualization window
+ - Clean up resources
+
+ ## Integration
+
+ To use this in your own code, see `onnx_doa_inference.py`, which provides a standalone inference class that can be integrated into other projects.
doa_model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:148f874c4fdc302a4d1808d6d0a45e1b3b40aeb6f8c6b2ef7e423710b2b28cba
+ size 785447
features.py ADDED
@@ -0,0 +1,186 @@
+ import numpy as np
+ from numpy.fft import rfft
+ from numpy.lib.stride_tricks import as_strided
+ from scipy.signal import get_window
+
+ def stft_multi(
+     x,
+     fs: float,
+     win_s: float = 0.032,
+     hop_s: float = 0.010,
+     nfft: int | None = None,
+     window: str | tuple | np.ndarray = "hann",
+     center: bool = True,
+     pad_mode: str = "reflect",
+     out_dtype=np.complex64,
+ ):
+     """
+     Multichannel STFT (vectorized).
+
+     Args
+     ----
+     x : np.ndarray, shape (N, C) time-domain signal
+     fs : float, sampling rate (Hz)
+     win_s : float, window length in seconds (default 32 ms)
+     hop_s : float, hop length in seconds (default 10 ms)
+     nfft : int or None. If None, uses the next power of two >= frame_len
+     window : str/tuple/array for scipy.signal.get_window, or a length-L array
+     center : if True, pad by L//2 on both sides (librosa-style)
+     pad_mode: np.pad mode (e.g., "reflect", "constant")
+     out_dtype: dtype for the STFT output (complex64 recommended)
+
+     Returns
+     -------
+     X : np.ndarray, shape (T, C, F) complex STFT
+     freqs: np.ndarray, shape (F,) frequency bins in Hz
+     times: np.ndarray, shape (T,) frame center times in seconds
+     """
+     x = np.asarray(x)
+     if x.ndim == 1:
+         x = x[:, None]  # (N, 1)
+     assert x.ndim == 2, "x must be (samples, channels)"
+     N, C = x.shape
+
+     # Window & hop in samples
+     frame_len = int(round(win_s * fs))
+     hop = int(round(hop_s * fs))
+     if frame_len <= 0 or hop <= 0:
+         raise ValueError("win_s and hop_s must be > 0")
+
+     # FFT size
+     def _next_pow2(n):
+         return 1 << (int(n - 1).bit_length())
+     nfft = _next_pow2(frame_len) if nfft is None else int(nfft)
+     if nfft < frame_len:
+         raise ValueError("nfft must be >= frame_len")
+
+     # Window vector
+     if isinstance(window, np.ndarray):
+         w = window.astype(float, copy=False)
+     else:
+         w = get_window(window, frame_len, fftbins=True).astype(float)
+     if w.shape[0] != frame_len:
+         raise ValueError("Provided window length != frame_len")
+
+     # Optional centering (pad by L//2 on both sides)
+     pad = frame_len // 2 if center else 0
+     if pad > 0:
+         x_pad = np.pad(x, ((pad, pad), (0, 0)), mode=pad_mode)
+     else:
+         x_pad = x
+
+     Np = x_pad.shape[0]
+     if Np < frame_len:
+         # ensure at least one frame
+         x_pad = np.pad(x_pad, ((0, frame_len - Np), (0, 0)), mode=pad_mode)
+         Np = x_pad.shape[0]
+
+     # Number of frames
+     T = 1 + (Np - frame_len) // hop
+     if T <= 0:
+         raise ValueError("Signal too short for given window/hop")
+
+     # Stride-trick framing: (T, frame_len, C) view into x_pad
+     s_t, s_c = x_pad.strides  # bytes per step in time/channel
+     frames = as_strided(
+         x_pad,
+         shape=(T, frame_len, C),
+         strides=(hop * s_t, s_t, s_c),
+         writeable=False,
+     )
+     # Reorder to (T, C, frame_len) to apply the window & FFT along the last axis
+     frames = np.transpose(frames, (0, 2, 1))  # (T, C, L)
+
+     # Apply the window (broadcast over T and C)
+     frames = frames * w[None, None, :]
+
+     # Batched real FFT along the last axis -> (T, C, F)
+     X = rfft(frames, n=nfft, axis=-1).astype(out_dtype, copy=False)
+
+     # Frequency and time vectors
+     F = X.shape[-1]
+     freqs = (fs / nfft) * np.arange(F)
+     # Frame centers relative to the original signal
+     if center:
+         # centers at sample indices t*hop (librosa convention)
+         times = (np.arange(T) * hop) / fs
+     else:
+         # window centered at (frame_len/2) + t*hop
+         times = (np.arange(T) * hop + frame_len / 2.0) / fs
+
+     return X, freqs, times
+
+ def _wrap_to_2pi(x: np.ndarray) -> np.ndarray:
+     """Wrap angles to [0, 2π)."""
+     return np.mod(x, 2.0 * np.pi)
+
+ def compute_mag_phase(
+     X: np.ndarray,
+     dtype=np.float32,
+ ):
+     """
+     Per-channel magnitude and absolute phase (wrapped to [0, 2π)).
+
+     Args
+     ----
+     X : np.ndarray, shape (T, C, F), complex STFT
+     dtype: output dtype
+
+     Returns
+     -------
+     mag : np.ndarray, shape (T, C, F) = |X|
+     phase : np.ndarray, shape (T, C, F) = angle(X) in [0, 2π)
+     """
+     assert X.ndim == 3, "X must be (T, C, F)"
+     mag = np.abs(X).astype(dtype, copy=False)
+     phase = _wrap_to_2pi(np.angle(X)).astype(dtype, copy=False)
+     return mag, phase
+
+ def compute_mag_phase_cos_sin(
+     X: np.ndarray,
+     dtype=np.float32,
+ ):
+     """
+     Concatenate per-channel magnitude, cos(phase), sin(phase).
+
+     Args
+     ----
+     X : np.ndarray, shape (T, C, F), complex STFT
+     dtype: output dtype
+
+     Returns
+     -------
+     feats : np.ndarray, shape (T, 3*C, F)
+         Layout = [mag (C), cos(phase) (C), sin(phase) (C)]
+         where phase is angle(X) wrapped to [0, 2π).
+     """
+     mag, phase = compute_mag_phase(X, dtype=dtype)
+     cos_phase = np.cos(phase).astype(dtype, copy=False)
+     sin_phase = np.sin(phase).astype(dtype, copy=False)
+     feats = np.concatenate([mag, cos_phase, sin_phase], axis=1)
+     return feats
+
+ def compute_real_imag_features(
+     X: np.ndarray,
+     dtype=np.float32,
+ ):
+     """
+     Concatenate per-channel real and imaginary parts.
+
+     Args
+     ----
+     X : np.ndarray, shape (T, C, F), complex STFT
+     dtype: output dtype
+
+     Returns
+     -------
+     feats : np.ndarray, shape (T, 2*C, F)
+         Layout = [Re (C), Im (C)]
+     """
+     assert X.ndim == 3, "X must be (T, C, F)"
+     real = X.real.astype(dtype, copy=False)
+     imag = X.imag.astype(dtype, copy=False)
+     feats = np.concatenate([real, imag], axis=1)
+     return feats
+
onnx_stream_microphone.py ADDED
@@ -0,0 +1,796 @@
+ #!/usr/bin/env python3
+ """
+ Real-time DOA inference using an ONNX model with microphone streaming.
+ Includes histogram-based detection, event gates, and onset detection.
+ """
+
+ import sys
+ from pathlib import Path
+
+ # Add the src directory to the Python path
+ project_root = Path(__file__).parent
+ src_dir = project_root / "src"
+ if str(src_dir) not in sys.path:
+     sys.path.insert(0, str(src_dir))
+
+ import math
+ import numpy as np
+ import time
+ import queue
+ import argparse
+ import pyaudio
+ import onnxruntime as ort
+ import yaml
+ from typing import Optional, Dict, List, Tuple
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import Circle
+ import torch
+ import torch.nn.functional as F
+
+ from mirokai_doa.features import stft_multi, compute_mag_phase_cos_sin
+
+ # -------------------------
+ # Math helpers (numpy version)
+ # -------------------------
+ def _angles_deg_np(K: int):
+     bin_size = 360.0 / K
+     deg = (np.arange(K, dtype=np.float32) + 0.5) * bin_size
+     rad = deg * np.pi / 180.0
+     return deg, np.cos(rad), np.sin(rad), bin_size
+
+ def _softmax_temp_np(logits: np.ndarray, tau: float = 0.8) -> np.ndarray:
+     exp_logits = np.exp((logits - np.max(logits, axis=-1, keepdims=True)) / tau)
+     return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
+
+ def _circular_window_sum_np(row: np.ndarray, idx: int, half_w: int) -> float:
+     K = row.size
+     if half_w <= 0:
+         return float(row[idx])
+     acc = 0.0
+     for d in range(-half_w, half_w + 1):
+         acc += float(row[(idx + d) % K])
+     return acc
+
+ def _parabolic_peak_refine_np(row: np.ndarray, k: int) -> float:
+     K = row.size
+     km1, kp1 = (k - 1) % K, (k + 1) % K
+     y1, y2, y3 = float(row[km1]), float(row[k]), float(row[kp1])
+     denom = (y1 - 2 * y2 + y3)
+     if abs(denom) < 1e-9:
+         return 0.0
+     delta = 0.5 * (y1 - y3) / denom
+     return float(max(min(delta, 0.5), -0.5))
+
+ def _min_circ_separation_bins(a: int, chosen: List[int], K: int) -> int:
+     if not chosen:
+         return K
+     dmin = K
+     for j in chosen:
+         d = abs(a - j)
+         d = min(d, K - d)
+         dmin = min(dmin, d)
+     return dmin
+
+ # -------------------------
+ # Audio helpers
+ # -------------------------
+ def byte_to_float(data: bytes) -> np.ndarray:
+     samples = np.frombuffer(data, dtype=np.int16)
+     return samples.astype(np.float32) / 32768.0
+
+ def chunk_to_floatarray(data: bytes, channels: int) -> np.ndarray:
+     float_data = byte_to_float(data)
+     return float_data.reshape(-1, channels).T
+
+ def rms_dbfs(x: np.ndarray, eps: float = 1e-9) -> float:
+     val = np.sqrt((x * x).mean())
+     return 20.0 * np.log10(max(val, eps))
+
+ def frame_rms_energy(audio_buffer: np.ndarray, T: int) -> np.ndarray:
+     """Split audio_buffer (C, N) into T equal segments; return per-frame RMS (normalized)."""
+     C, N = audio_buffer.shape
+     if T <= 0:
+         return np.ones(1, dtype=np.float32)
+     edges = np.linspace(0, N, T + 1, dtype=int)
+     e = []
+     for i in range(T):
+         seg = audio_buffer[:, edges[i]:edges[i+1]]
+         if seg.size == 0:
+             e.append(0.0)
+         else:
+             rms = np.sqrt((seg * seg).mean())
+             e.append(rms)
+     e = np.asarray(e, dtype=np.float32)
+     e = e / max(e.mean(), 1e-6)
+     return e
+
+ def spectral_flux_per_frame(audio_buffer: np.ndarray, T: int) -> np.ndarray:
+     """Compute per-frame spectral flux across T segments of the mono mix."""
+     C, N = audio_buffer.shape
+     if T <= 1:
+         return np.zeros((T,), dtype=np.float32)
+     mono = audio_buffer.mean(axis=0)
+     edges = np.linspace(0, N, T + 1, dtype=int)
+     mags = []
+     for i in range(T):
+         seg = mono[edges[i]:edges[i+1]]
+         if seg.size == 0:
+             mags.append(np.zeros(1, dtype=np.float32))
+             continue
+         win = np.hanning(len(seg)) if len(seg) > 8 else np.ones_like(seg)
+         S = np.fft.rfft(seg * win, n=len(seg))
+         mags.append(np.abs(S).astype(np.float32))
+     flux = np.zeros(T, dtype=np.float32)
+     for t in range(1, T):
+         a = mags[t-1]
+         b = mags[t]
+         L = min(len(a), len(b))
+         if L == 0:
+             flux[t] = 0.0
+             continue
+         diff = b[:L] - a[:L]
+         pos = np.maximum(diff, 0.0)
+         denom = np.sum(b[:L]) + 1e-6
+         flux[t] = float(np.sum(pos) / denom)
+     return flux
+
+ # -------------------------
+ # Onset detector
+ # -------------------------
+ class OnsetDetector:
+     def __init__(self, alpha: float = 0.05):
+         self.alpha = float(alpha)
+         self.mu = 0.0
+         self.var = 1.0
+         self.inited = False
+
+     def update_flux(self, flux_recent: float) -> float:
+         if not self.inited:
+             self.mu = flux_recent
+             self.var = 1e-3 + abs(flux_recent)
+             self.inited = True
+         delta = flux_recent - self.mu
+         self.mu += self.alpha * delta
+         self.var = (1 - self.alpha) * self.var + self.alpha * delta * delta
+         sigma = max(np.sqrt(self.var), 1e-6)
+         z = (flux_recent - self.mu) / sigma
+         return float(z)
+
+     @staticmethod
+     def last_segment_coherence(audio_buffer: np.ndarray, T: int,
+                                pairs: List[Tuple[int, int]] = [(0, 1), (0, 2), (0, 3)]) -> float:
+         C, N = audio_buffer.shape
+         if T < 1:
+             return 0.0
+         edges = np.linspace(0, N, T + 1, dtype=int)
+         s0, s1 = int(edges[-2]), int(edges[-1])
+         seg = audio_buffer[:, s0:s1]
+         if seg.shape[1] < 16:
+             return 0.0
+         rmax = 0.0
+         for (i, j) in pairs:
+             xi = seg[i] - seg[i].mean()
+             xj = seg[j] - seg[j].mean()
+             denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-9)
+             r = float(np.dot(xi, xj) / denom)
+             rmax = max(rmax, abs(r))
+         return rmax
+
+ # -------------------------
+ # Histogram DOA detector (numpy/torch hybrid)
+ # -------------------------
+ class HistDOADetector:
+     def __init__(
+         self,
+         K: int = 72,
+         tau: float = 0.8,
+         gamma: float = 1.5,
+         smooth_k: int = 1,
+         window_bins: int = 1,
+         min_peak_height: float = 0.10,
+         min_window_mass: float = 0.24,
+         min_sep_deg: float = 20.0,
+         min_active_ratio: float = 0.20,
+         max_sources: int = 3,
+         device: str = "cpu",
+     ):
+         self.K = int(K)
+         self.tau = float(tau)
+         self.gamma = float(gamma)
+         self.smooth_k = int(smooth_k)
+         self.window_bins = int(window_bins)
+         self.min_peak_height = float(min_peak_height)
+         self.min_window_mass = float(min_window_mass)
+         self.min_sep_deg = float(min_sep_deg)
+         self.min_active_ratio = float(min_active_ratio)
+         self.max_sources = int(max_sources)
+         self.device = torch.device(device)
+         self._deg, self._cos, self._sin, self._bin_size = self._angles_deg(self.K)
+
+     def _angles_deg(self, K: int):
+         bin_size = 360.0 / K
+         deg = torch.arange(K, device=self.device, dtype=torch.float32) + 0.5
+         deg = deg * bin_size
+         rad = deg * math.pi / 180.0
+         return deg, torch.cos(rad), torch.sin(rad), bin_size
+
+     def _aggregate_histogram(self, logits: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, float, float]:
+         """Aggregate a histogram from logits and a VAD mask."""
+         logits_t = torch.from_numpy(logits).float().to(self.device)
+         mask_t = torch.from_numpy(mask).float().to(self.device)
+
+         probs = F.softmax(logits_t / self.tau, dim=-1)  # [T, K]
+         T = probs.shape[0]
+         m = mask_t
+
+         # Weighted histogram
+         x = torch.matmul(probs, self._cos)
+         y = torch.matmul(probs, self._sin)
+         R_t = torch.clamp(torch.sqrt(x * x + y * y), 0, 1)
+         w = m * (R_t ** self.gamma)
+
+         if w.sum() <= 0:
+             w = torch.ones_like(w) * 1e-6
+
+         hist = torch.matmul(w, probs)
+         hist = hist / hist.sum().clamp_min(1e-8)
+
+         if self.smooth_k > 0:
+             s = self.smooth_k
+             pad = torch.cat([hist[-s:], hist, hist[:s]], dim=0).view(1, 1, -1)
+             kernel = torch.ones(1, 1, 2 * s + 1, device=self.device) / (2 * s + 1)
+             hist = F.conv1d(pad, kernel, padding=0).view(-1)
+
+         X = torch.dot(hist, self._cos)
+         Y = torch.dot(hist, self._sin)
+         R_clip = float(torch.sqrt(X * X + Y * Y).item())
+         active_ratio = float(m.mean().item())
+         return hist.detach().cpu().numpy(), active_ratio, R_clip
+
+     def _pick_peaks(self, hist: np.ndarray) -> List[Dict[str, float]]:
+         """Pick peaks from the histogram."""
+         hist_t = torch.from_numpy(hist).float()
+         K = self.K
+         bin_size = self._bin_size
+
+         left = torch.roll(hist_t, 1, 0)
+         right = torch.roll(hist_t, -1, 0)
+         cand_idxs = ((hist_t > left) & (hist_t > right)).nonzero(as_tuple=False).flatten().tolist()
+         cand_idxs.sort(key=lambda i: float(hist_t[i].item()), reverse=True)
+
+         chosen, out = [], []
+         min_sep_bins = max(1, int(round(self.min_sep_deg / bin_size)))
+
+         for idx in cand_idxs:
+             if _min_circ_separation_bins(idx, chosen, K) < min_sep_bins:
+                 continue
+             if float(hist_t[idx].item()) < self.min_peak_height:
+                 continue
+             mass = _circular_window_sum_np(hist, idx, self.window_bins)
+             if mass < self.min_window_mass:
+                 continue
+             delta = _parabolic_peak_refine_np(hist, idx)
+             angle_deg = ((idx + 0.5 + delta) * bin_size) % 360.0
+             out.append({"azimuth_deg": angle_deg, "score": float(mass)})
+             chosen.append(idx)
+             if len(out) >= self.max_sources:
+                 break
+         return out
+
+     def detect(self, logits: np.ndarray) -> dict:
+         """Detect DOA from logits (no VAD separation)."""
+         # Use all frames (no VAD masking)
+         mask = np.ones(logits.shape[0], dtype=np.float32)
+
+         hist, active_ratio, R_clip = self._aggregate_histogram(logits, mask)
+
+         peaks = self._pick_peaks(hist) if active_ratio >= self.min_active_ratio else []
+
+         bins_deg = (np.arange(self.K) + 0.5) * (360.0 / self.K)
+         return {
+             "peaks": peaks,
+             "active_ratio": active_ratio,
+             "R_clip": R_clip,
+             "hist": hist,
+             "bins_deg": bins_deg,
+             "has_event": bool(peaks),
+         }
+
# -------------------------
# Event gate
# -------------------------
class LevelChangeGate:
    def __init__(
        self,
        delta_on_db: float = 2.5,
        delta_off_db: float = 1.0,
        level_min_dbfs: float = -60.0,
        ema_alpha: float = 0.05,
        min_R_clip: float = 0.18,
        hold_ms: int = 300,
        refractory_ms: int = 120
    ):
        self.delta_on_db = float(delta_on_db)
        self.delta_off_db = float(delta_off_db)
        self.level_min_dbfs = float(level_min_dbfs)
        self.ema_alpha = float(ema_alpha)
        self.min_R_clip = float(min_R_clip)
        self.hold_s = float(hold_ms) / 1000.0
        self.refractory_s = float(refractory_ms) / 1000.0
        self.bg_dbfs = None
        self.active = False
        self.last_change_time = 0.0

    def update(self, level_dbfs: float, now_s: float,
               peaks_count: int, R_clip_max: float):
        if self.bg_dbfs is None:
            self.bg_dbfs = level_dbfs
        diff_db = level_dbfs - self.bg_dbfs

        want_open = (
            (now_s - self.last_change_time) >= self.refractory_s and
            ((level_dbfs > self.level_min_dbfs and diff_db >= self.delta_on_db) or
             (peaks_count > 0 and R_clip_max >= self.min_R_clip))
        )

        if not self.active:
            if want_open:
                self.active = True
                self.last_change_time = now_s
        else:
            if (now_s - self.last_change_time) >= self.hold_s:
                want_close = (
                    (diff_db <= self.delta_off_db) and
                    (peaks_count == 0 or R_clip_max < self.min_R_clip)
                )
                if want_close:
                    self.active = False
                    self.last_change_time = now_s

        self.bg_dbfs = (1.0 - self.ema_alpha) * self.bg_dbfs + self.ema_alpha * level_dbfs
        return self.active, diff_db

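The gate is hysteretic: it opens only when the level jumps `delta_on_db` above a slowly tracked EMA background and closes once the difference falls back under `delta_off_db`. A standalone toy replica of just that hysteresis (the refractory, hold, and DOA-evidence paths are left out for brevity, and the dBFS trace is made up):

```python
# Simplified hysteresis replica: EMA background + asymmetric on/off thresholds.
delta_on, delta_off, alpha = 2.5, 1.0, 0.05
bg, active, states = None, False, []
levels = [-45.0] * 5 + [-38.0] * 5 + [-45.0] * 5  # quiet, burst, quiet (dBFS)

for lvl in levels:
    if bg is None:
        bg = lvl                      # seed background with the first level
    diff = lvl - bg
    if not active and diff >= delta_on:
        active = True                 # open: level jumped above background
    elif active and diff <= delta_off:
        active = False                # close: level settled back down
    bg = (1.0 - alpha) * bg + alpha * lvl  # background slowly tracks the level
    states.append(active)
```

Because the background EMA only partially catches up during the burst, the gate stays open for the whole burst and closes as soon as the level returns to the old floor; the asymmetric thresholds prevent chattering around a single threshold.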
# -------------------------
# ONNX Inference
# -------------------------
class ONNXDOAStreaming:
    def __init__(
        self,
        onnx_path: str,
        config_path: Optional[str] = None,
        providers: Optional[list] = None
    ):
        if config_path is None:
            config_path = project_root / "configs" / "train.yaml"
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.features_cfg = self.config.get('features', {})
        self.sr = self.features_cfg.get('sr', 16000)
        self.win_s = self.features_cfg.get('win_s', 0.032)
        self.hop_s = self.features_cfg.get('hop_s', 0.010)
        self.nfft = self.features_cfg.get('nfft', 1024)
        self.K = self.features_cfg.get('K', 72)

        if providers is None:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self.session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        input_shape = self.session.get_inputs()[0].shape
        self.is_doa_model = len(input_shape) == 4 and input_shape[-1] == 513

        print(f"ONNX Model loaded: {onnx_path}")
        print(f"  Input shape: {input_shape}")
        print(f"  Model type: {'DoAEstimator' if self.is_doa_model else 'TFPoolClassifierNoCond'}")
        print(f"  Providers: {self.session.get_providers()}")

    def compute_features(self, mixture: np.ndarray) -> np.ndarray:
        if mixture.ndim == 1:
            raise ValueError("Mixture must be multichannel (4 channels)")
        if mixture.shape[0] != 4 and mixture.shape[1] == 4:
            mixture = mixture.T

        if mixture.shape[0] != 4:
            raise ValueError(f"Expected 4 channels, got {mixture.shape[0]}")

        x4 = mixture.astype(np.float32)
        X, freqs, times = stft_multi(x4.T, fs=self.sr, win_s=self.win_s, hop_s=self.hop_s,
                                     nfft=self.nfft, window="hann", center=True, pad_mode="reflect")
        feats = compute_mag_phase_cos_sin(X, dtype=np.float32)
        return feats

    def inference_batch(self, feats: np.ndarray, batch_size: int = 25) -> np.ndarray:
        T_frames, C_feat, F = feats.shape
        assert C_feat == 12, f"Expected 12 feature channels, got {C_feat}"

        all_logits = []
        for start_idx in range(0, T_frames, batch_size):
            end_idx = min(start_idx + batch_size, T_frames)
            batch_feats = feats[start_idx:end_idx]
            batch_T = batch_feats.shape[0]

            # Zero-pad the last batch to the fixed temporal size the model expects
            if batch_T < batch_size:
                padding = np.zeros((batch_size - batch_T, C_feat, F), dtype=batch_feats.dtype)
                batch_feats = np.concatenate([batch_feats, padding], axis=0)

            # (T, C, F) -> (1, C, T, F) for the model
            feats_tensor = batch_feats.transpose(1, 0, 2)[np.newaxis, ...]
            outputs = self.session.run([self.output_name], {self.input_name: feats_tensor.astype(np.float32)})
            batch_logits = outputs[0]

            # Normalize the output to (batch_T, K) regardless of head shape
            if batch_logits.ndim == 2:
                if batch_logits.shape[0] == 1 and batch_logits.shape[1] == self.K:
                    batch_logits = np.tile(batch_logits, (batch_T, 1))
                elif batch_logits.shape[0] == 1:
                    batch_logits = batch_logits[0]
                else:
                    batch_logits = batch_logits[:batch_T]
            elif batch_logits.ndim == 3:
                batch_logits = batch_logits[0, :batch_T]

            all_logits.append(batch_logits)

        return np.concatenate(all_logits, axis=0)

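`inference_batch` always feeds the model fixed-size `(1, 12, 25, 513)` tensors, zero-padding the last batch of frames. A shape-only walk-through with dummy features (the frame count 21 is just an illustrative example: a 200 ms window at a 10 ms STFT hop with `center=True` typically yields 3200 // 160 + 1 = 21 frames):

```python
import numpy as np

batch_size, C_feat, F = 25, 12, 513
T_frames = 21                                       # example frame count (see lead-in)
feats = np.zeros((T_frames, C_feat, F), dtype=np.float32)

batch = feats[0:batch_size]                         # only batch here: 21 frames
pad = np.zeros((batch_size - batch.shape[0], C_feat, F), dtype=np.float32)
batch = np.concatenate([batch, pad], axis=0)        # zero-pad to (25, 12, 513)
tensor = batch.transpose(1, 0, 2)[np.newaxis, ...]  # (T, C, F) -> (1, C, T, F)
```

Padding to a fixed temporal size keeps the ONNX input shape static, which lets the runtime reuse its optimized execution plan instead of re-planning per call; the padded frames' logits are then dropped via the `[:batch_T]` slice.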
# -------------------------
# Visualization
# -------------------------
class CurrentLineVisualizer:
    def __init__(self, title: str = "Current DOA"):
        self.fig = plt.figure(figsize=(7.5, 7.5))
        self.ax = self.fig.add_subplot(111, projection='polar')
        self._setup_axes(title)
        plt.ion()
        plt.show(block=False)

    def _setup_axes(self, title: str):
        self.ax.clear()
        self.ax.set_title(title, fontsize=13, fontweight='bold', pad=16)
        self.ax.set_theta_zero_location('N')
        self.ax.set_theta_direction(-1)
        self.ax.set_thetalim(0, 2 * np.pi)
        self.ax.set_ylim(0, 1.05)
        self.ax.set_yticklabels([])
        self.ax.add_patch(Circle((0, 0), 1.0, fill=False, color='gray', linestyle='--', linewidth=1, alpha=0.5))
        self.ax.grid(alpha=0.2)

    def update(self, peaks: List[Dict]):
        self._setup_axes("Current DOA")

        for pk in peaks[:3]:
            az = float(pk["azimuth_deg"])
            sc = float(pk.get("score", 0.2))
            lw = 2.0 + 5.0 * float(np.clip(sc, 0.0, 0.6))
            theta = np.deg2rad(az)
            self.ax.plot([theta, theta], [0.0, 1.0], color='tab:green', linewidth=lw, solid_capstyle='round')
            self.ax.text(theta, 1.02, f"{az:.0f}°", ha='center', va='bottom', fontsize=10,
                         color='tab:green', fontweight='bold')

        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()

# -------------------------
# Main streaming function
# -------------------------
def stream_onnx_inference(
    onnx_path: str,
    config_path: Optional[str] = None,
    device_index: Optional[int] = None,
    sample_rate: int = 16000,
    window_ms: int = 200,
    hop_ms: int = 100,
    chunk_size: int = 1600,
    cpu_only: bool = False,
    # Histogram params
    K: int = 72,
    tau: float = 0.8,
    smooth_k: int = 1,
    min_peak_height: float = 0.10,
    min_window_mass: float = 0.24,
    min_sep_deg: float = 20.0,
    min_active_ratio: float = 0.20,
    max_sources: int = 3,
    # Event gate params
    level_delta_on_db: float = 2.5,
    level_delta_off_db: float = 1.0,
    level_min_dbfs: float = -60.0,
    level_ema_alpha: float = 0.05,
    event_hold_ms: int = 300,
    min_R_clip: float = 0.18,
    event_refractory_ms: int = 120,
    # Onset params
    onset_alpha: float = 0.05,
):
    """Stream inference from the microphone using an ONNX model."""

    providers = ['CPUExecutionProvider'] if cpu_only else ['CUDAExecutionProvider', 'CPUExecutionProvider']
    infer = ONNXDOAStreaming(onnx_path, config_path, providers=providers)

    # The model's bin count wins if it differs from the requested K
    if K != infer.K:
        print(f"Warning: K mismatch. Model K={infer.K}, requested K={K}. Using model K.")
        K = infer.K

    det = HistDOADetector(
        K=K, tau=tau, gamma=1.5, smooth_k=smooth_k,
        window_bins=1, min_peak_height=min_peak_height, min_window_mass=min_window_mass,
        min_sep_deg=min_sep_deg, min_active_ratio=min_active_ratio, max_sources=max_sources,
        device="cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
    )

    gate = LevelChangeGate(
        delta_on_db=level_delta_on_db, delta_off_db=level_delta_off_db,
        level_min_dbfs=level_min_dbfs, ema_alpha=level_ema_alpha,
        min_R_clip=min_R_clip,
        hold_ms=event_hold_ms, refractory_ms=event_refractory_ms
    )

    onset = OnsetDetector(alpha=onset_alpha)
    visualizer = CurrentLineVisualizer()

    window_samples = int(sample_rate * window_ms / 1000)
    hop_samples = int(sample_rate * hop_ms / 1000)

    p = pyaudio.PyAudio()

    if device_index is None:
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            name = info['name'].lower()
            # Check by name first (PulseAudio might hide channels)
            if 'respeaker' in name or 'seeed' in name or '2886' in name:
                device_index = i
                print(f"Auto-detected ReSpeaker at device {i}: {info['name']}")
                break

    if device_index is None:
        print("\n[Audio] Could not auto-detect ReSpeaker. Use --device-index or --list-devices.\n")
        p.terminate()
        return

    # Check device info
    device_info = p.get_device_info_by_index(device_index)
    print(f"Device info: {device_info['name']}")
    print(f"  Max input channels: {device_info['maxInputChannels']}")
    print(f"  Default sample rate: {device_info['defaultSampleRate']:.0f} Hz")

    # If the device reports 0 channels, it is likely managed by PulseAudio.
    # We still try to open it - sometimes that works despite the report.
    if device_info['maxInputChannels'] == 0:
        print("  Warning: Device reports 0 channels (may be managed by PulseAudio)")
        print("  Attempting to open anyway...")

    CHANNELS = 6
    RAW_CHANNELS = [1, 4, 3, 2]  # requested channel order
    FORMAT = pyaudio.paInt16

    audio_buffer = np.zeros((4, window_samples), dtype=np.float32)
    buffer_fill = 0
    start_time = time.time()

    audio_queue = queue.Queue()
    stream_closed = False

    def _fill_buffer(in_data, frame_count, time_info, status_flags):
        if not stream_closed:
            audio_queue.put(in_data)
        return None, pyaudio.paContinue

    try:
        # Try to open the stream - PyAudio will validate the channel count
        stream = p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=sample_rate,
            input=True,
            input_device_index=device_index,
            frames_per_buffer=chunk_size,
            stream_callback=_fill_buffer
        )
        print("  Successfully opened audio stream with 6 channels")
    except Exception as e:
        print(f"\n[Audio] Could not open input device (index {device_index}).")
        print(f"  Error: {e}")
        print("\n  The ReSpeaker device is likely locked by PulseAudio.")
        print("  Solutions:")
        print("    1. Temporarily stop PulseAudio: pulseaudio --kill")
        print("    2. Then restart it afterwards: pulseaudio --start")
        print("    3. Or configure PulseAudio to allow direct ALSA access\n")
        p.terminate()
        return

    stream.start_stream()
    print(f"\n[Streaming] Started. Window: {window_ms}ms, Hop: {hop_ms}ms")
    print("  Press Ctrl+C to stop.\n")

    try:
        while True:
            try:
                data = audio_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            chunk_all = chunk_to_floatarray(data, CHANNELS)  # (6, N)
            audio_chunk = chunk_all[RAW_CHANNELS, :]         # (4, N)
            n = audio_chunk.shape[1]

            if buffer_fill + n <= window_samples:
                audio_buffer[:, buffer_fill:buffer_fill + n] = audio_chunk
                buffer_fill += n
                continue

            remaining = window_samples - buffer_fill
            if remaining > 0:
                audio_buffer[:, buffer_fill:] = audio_chunk[:, :remaining]
            buffer_fill = window_samples

            # Inference
            t0 = time.perf_counter()
            feats = infer.compute_features(audio_buffer)
            logits = infer.inference_batch(feats)
            t_model = (time.perf_counter() - t0) * 1000.0

            T = logits.shape[0]
            energies = frame_rms_energy(audio_buffer, T)
            flux = spectral_flux_per_frame(audio_buffer, T)
            flux_recent = float(max(flux[-1], flux[-2] if T >= 2 else 0.0))
            flux_z = onset.update_flux(flux_recent)
            coh = OnsetDetector.last_segment_coherence(audio_buffer, T)

            # DOA detection (no VAD)
            t1 = time.perf_counter()
            det_result = det.detect(logits)
            t_hist = (time.perf_counter() - t1) * 1000.0

            peaks = det_result["peaks"]
            peaks_count = len(peaks)
            Rmax = det_result["R_clip"]

            level = rms_dbfs(audio_buffer)
            now = time.time() - start_time

            gate_open, diff_db = gate.update(level_dbfs=level, now_s=now,
                                             peaks_count=peaks_count, R_clip_max=Rmax)

            if gate_open:
                visualizer.update(peaks)
                gate_str = "OPEN  "
            else:
                visualizer.update([])
                gate_str = "CLOSED"

            print(f"[{now:6.2f}s] LVL={level:6.1f} dBFS diff={diff_db:+4.1f} | "
                  f"FLUXz={flux_z:4.2f} COH={coh:4.2f} | "
                  f"GATE={gate_str} | "
                  f"MODEL={t_model:5.1f}ms HIST={t_hist:5.1f}ms | "
                  f"DOA(R={Rmax:.2f}, n={peaks_count})", end="")
            if peaks:
                az_str = ", ".join([f"{pk['azimuth_deg']:.0f}°" for pk in peaks[:3]])
                print(f" [{az_str}]")
            else:
                print()

            # Slide buffer
            audio_buffer[:, :-hop_samples] = audio_buffer[:, hop_samples:]
            buffer_fill = window_samples - hop_samples

            if n > remaining:
                carry = min(n - remaining, hop_samples)
                if carry > 0:
                    audio_buffer[:, buffer_fill:buffer_fill + carry] = audio_chunk[:, remaining:remaining + carry]
                    buffer_fill += carry

    except KeyboardInterrupt:
        print("\n[Streaming] Stopped by user.")
    finally:
        stream_closed = True
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
        p.terminate()
        plt.close('all')

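The sliding-buffer step near the end of the loop can be isolated: the `(4, window)` buffer is shifted left by `hop_samples`, keeping the `window - hop` overlap for the next inference while the freed tail fills with fresh samples. A toy run with made-up sizes (the real values at 16 kHz are a 3200-sample window and a 1600-sample hop):

```python
import numpy as np

window_samples, hop_samples = 8, 3              # toy sizes for illustration
buf = np.tile(np.arange(window_samples, dtype=np.float32), (4, 1))  # (4, 8)

buf[:, :-hop_samples] = buf[:, hop_samples:]    # slide left: drop the oldest hop
buffer_fill = window_samples - hop_samples      # tail slots now free for new audio
```

After the shift the retained overlap starts at sample index `hop_samples`, so consecutive analysis windows share `window - hop` samples, exactly the 50% overlap the 200 ms window / 100 ms hop defaults produce.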
def main():
    parser = argparse.ArgumentParser(description="Stream ONNX DOA inference from microphone")
    parser.add_argument('--onnx', type=str, required=False, help='Path to ONNX model file')
    parser.add_argument('--config', type=str, default=None, help='Path to config.yaml')
    parser.add_argument('--device-index', type=int, default=None, help='Audio device index')
    parser.add_argument('--sample-rate', type=int, default=16000, help='Sample rate (Hz)')
    parser.add_argument('--window-ms', type=int, default=200, help='Window length (ms)')
    parser.add_argument('--hop-ms', type=int, default=100, help='Hop length (ms)')
    parser.add_argument('--chunk-size', type=int, default=1600, help='Audio chunk size (samples)')
    parser.add_argument('--cpu-only', action='store_true', help='Use CPU only')
    parser.add_argument('--list-devices', action='store_true', help='List available audio devices')

    # Histogram params
    parser.add_argument('--K', type=int, default=72, help='Number of azimuth bins')
    parser.add_argument('--tau', type=float, default=0.8, help='Softmax temperature')
    parser.add_argument('--smooth-k', type=int, default=1, help='Smoothing kernel size')
    parser.add_argument('--min-peak-height', type=float, default=0.10, help='Min peak height')
    parser.add_argument('--min-window-mass', type=float, default=0.24, help='Min window mass')
    parser.add_argument('--min-sep-deg', type=float, default=20.0, help='Min separation (deg)')
    parser.add_argument('--min-active-ratio', type=float, default=0.20, help='Min active ratio')
    parser.add_argument('--max-sources', type=int, default=3, help='Max sources')

    # Event gate params
    parser.add_argument('--level-delta-on-db', type=float, default=2.5, help='Level delta on (dB)')
    parser.add_argument('--level-delta-off-db', type=float, default=1.0, help='Level delta off (dB)')
    parser.add_argument('--level-min-dbfs', type=float, default=-60.0, help='Min level (dBFS)')
    parser.add_argument('--level-ema-alpha', type=float, default=0.05, help='Level EMA alpha')
    parser.add_argument('--event-hold-ms', type=int, default=300, help='Event hold (ms)')
    parser.add_argument('--min-R-clip', type=float, default=0.18, help='Min R clip')
    parser.add_argument('--event-refractory-ms', type=int, default=120, help='Event refractory (ms)')

    # Onset params
    parser.add_argument('--onset-alpha', type=float, default=0.05, help='Onset EMA alpha')

    args = parser.parse_args()

    if args.list_devices:
        p = pyaudio.PyAudio()
        print("\nAvailable audio input devices:")
        print("-" * 80)
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if info['maxInputChannels'] > 0:
                print(f"Device {i}: {info['name']}")
                print(f"  Channels: {info['maxInputChannels']}, Sample Rate: {info['defaultSampleRate']:.0f} Hz\n")
        p.terminate()
        return

    if args.onnx is None:
        parser.error("--onnx is required (unless using --list-devices)")

    onnx_path = Path(args.onnx)
    if not onnx_path.exists():
        parser.error(f"ONNX model not found: {onnx_path}")

    stream_onnx_inference(
        onnx_path=str(onnx_path),
        config_path=args.config,
        device_index=args.device_index,
        sample_rate=args.sample_rate,
        window_ms=args.window_ms,
        hop_ms=args.hop_ms,
        chunk_size=args.chunk_size,
        cpu_only=args.cpu_only,
        K=args.K,
        tau=args.tau,
        smooth_k=args.smooth_k,
        min_peak_height=args.min_peak_height,
        min_window_mass=args.min_window_mass,
        min_sep_deg=args.min_sep_deg,
        min_active_ratio=args.min_active_ratio,
        max_sources=args.max_sources,
        level_delta_on_db=args.level_delta_on_db,
        level_delta_off_db=args.level_delta_off_db,
        level_min_dbfs=args.level_min_dbfs,
        level_ema_alpha=args.level_ema_alpha,
        event_hold_ms=args.event_hold_ms,
        min_R_clip=args.min_R_clip,
        event_refractory_ms=args.event_refractory_ms,
        onset_alpha=args.onset_alpha,
    )


if __name__ == "__main__":
    main()
silero_vad.onnx ADDED (Git LFS pointer)
version https://git-lfs.github.com/spec/v1
oid sha256:a35ebf52fd3ce5f1469b2a36158dba761bc47b973ea3382b3186ca15b1f5af28
size 1807522