Upload folder using huggingface_hub
Browse files- README.md +317 -0
- doa_model.onnx +3 -0
- features.py +186 -0
- onnx_stream_microphone.py +796 -0
- silero_vad.onnx +3 -0
README.md
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ONNX Real-Time DOA Streaming
|
| 2 |
+
|
| 3 |
+
Real-time Direction of Arrival (DOA) detection using an ONNX model with microphone streaming. This script processes audio from a multi-channel microphone array (ReSpeaker) in real-time and displays detected sound source directions.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The script performs the following process:
|
| 8 |
+
|
| 9 |
+
1. **Audio Capture**: Streams audio from a 6-channel microphone array (ReSpeaker)
|
| 10 |
+
2. **Channel Selection**: Selects and reorders channels `[1, 4, 3, 2]` to get 4 channels
|
| 11 |
+
3. **Feature Extraction**: Computes STFT features (magnitude, phase, cosine, sine) from the audio
|
| 12 |
+
4. **ONNX Inference**: Runs the DOA model on GPU (CUDA) or CPU to get per-frame logits
|
| 13 |
+
5. **Histogram Aggregation**: Aggregates logits into a circular histogram of azimuth angles
|
| 14 |
+
6. **Peak Detection**: Finds peaks in the histogram to identify sound source directions
|
| 15 |
+
7. **Event Gating**: Filters detections based on audio level changes and coherence
|
| 16 |
+
8. **Visualization**: Displays detected directions on a polar plot in real-time
|
| 17 |
+
|
| 18 |
+
## Prerequisites
|
| 19 |
+
|
| 20 |
+
### Hardware
|
| 21 |
+
- **ReSpeaker 6-Mic Array** (or compatible multi-channel microphone)
|
| 22 |
+
- microphone:
|
| 23 |
+
positions:
|
| 24 |
+
- [0.0277, 0.0] # Mic 0: 0°
|
| 25 |
+
- [0.0, 0.0277] # Mic 1: 90°
|
| 26 |
+
- [-0.0277, 0.0] # Mic 2: 180°
|
| 27 |
+
- [0.0, -0.0277] # Mic 3: 270°
|
| 28 |
+
- **NVIDIA GPU** (optional, for faster inference)
|
| 29 |
+
|
| 30 |
+
### Software Dependencies
|
| 31 |
+
|
| 32 |
+
Install the required packages:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
conda activate doaEnv
|
| 36 |
+
pip install onnxruntime-gpu # For GPU inference
|
| 37 |
+
# OR
|
| 38 |
+
pip install onnxruntime # For CPU-only inference
|
| 39 |
+
|
| 40 |
+
pip install pyaudio numpy matplotlib torch pyyaml
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
### ONNX Model
|
| 44 |
+
|
| 45 |
+
You need a converted ONNX model file. If you haven't converted your PyTorch model yet:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
python convert_to_onnx.py --checkpoint models/basic/2025-11-06_22-37-00-6a5fbc92/last.pt --output models/basic/2025-11-06_22-37-00-6a5fbc92/last.onnx
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Quick Start
|
| 52 |
+
|
| 53 |
+
### 1. List Available Audio Devices
|
| 54 |
+
|
| 55 |
+
First, find your ReSpeaker device index:
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
python onnx_stream_microphone.py --list-devices
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
Look for a device named "ReSpeaker" or "Seeed" or containing "2886". Note the device index.
|
| 62 |
+
|
| 63 |
+
### 2. Stop PulseAudio (Required)
|
| 64 |
+
|
| 65 |
+
On Linux, PulseAudio often locks the ALSA devices. You need to temporarily stop it:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
pulseaudio --kill
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**Note**: You can use the helper script `run_onnx_stream.sh` which automates this (see below).
|
| 72 |
+
|
| 73 |
+
### 3. Run the Streaming Script
|
| 74 |
+
|
| 75 |
+
Basic usage:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
python onnx_stream_microphone.py \
|
| 79 |
+
--onnx models/basic/2025-11-06_22-37-00-6a5fbc92/last.onnx \
|
| 80 |
+
--device-index 9
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### 4. Restart PulseAudio (After Stopping)
|
| 84 |
+
|
| 85 |
+
After you're done, restart PulseAudio:
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
pulseaudio --start
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Using the Helper Script
|
| 92 |
+
|
| 93 |
+
A helper script automates PulseAudio management:
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
chmod +x run_onnx_stream.sh
|
| 97 |
+
./run_onnx_stream.sh --onnx models/basic/2025-11-06_22-37-00-6a5fbc92/last.onnx --device-index 9
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
This script will:
|
| 101 |
+
1. Stop PulseAudio
|
| 102 |
+
2. Run the streaming script
|
| 103 |
+
3. Restart PulseAudio when you exit (Ctrl+C)
|
| 104 |
+
|
| 105 |
+
## Command-Line Arguments
|
| 106 |
+
|
| 107 |
+
### Required Arguments
|
| 108 |
+
|
| 109 |
+
- `--onnx PATH`: Path to the ONNX model file
|
| 110 |
+
|
| 111 |
+
### Audio Configuration
|
| 112 |
+
|
| 113 |
+
- `--device-index INT`: Audio device index (use `--list-devices` to find it)
|
| 114 |
+
- `--sample-rate INT`: Sample rate in Hz (default: 16000)
|
| 115 |
+
- `--window-ms INT`: Analysis window length in milliseconds (default: 200)
|
| 116 |
+
- `--hop-ms INT`: Hop size (overlap) in milliseconds (default: 100)
|
| 117 |
+
- `--chunk-size INT`: Audio buffer chunk size (default: 1600)
|
| 118 |
+
- `--cpu-only`: Use CPU only (disable GPU inference)
|
| 119 |
+
- `--list-devices`: List all available audio input devices and exit
|
| 120 |
+
|
| 121 |
+
### Model Configuration
|
| 122 |
+
|
| 123 |
+
- `--config PATH`: Path to config.yaml (default: `configs/train.yaml`)
|
| 124 |
+
|
| 125 |
+
### Histogram Detection Parameters
|
| 126 |
+
|
| 127 |
+
These control how DOA peaks are detected from the model logits:
|
| 128 |
+
|
| 129 |
+
- `--K INT`: Number of azimuth bins (default: 72, should match model)
|
| 130 |
+
- `--tau FLOAT`: Softmax temperature for histogram (default: 0.8)
|
| 131 |
+
- `--smooth-k INT`: Histogram smoothing kernel size (default: 1)
|
| 132 |
+
- `--min-peak-height FLOAT`: Minimum peak height threshold (default: 0.10)
|
| 133 |
+
- `--min-window-mass FLOAT`: Minimum window mass for peak validation (default: 0.24)
|
| 134 |
+
- `--min-sep-deg FLOAT`: Minimum angular separation between peaks in degrees (default: 20.0)
|
| 135 |
+
- `--min-active-ratio FLOAT`: Minimum active frame ratio (default: 0.20)
|
| 136 |
+
- `--max-sources INT`: Maximum number of sources to detect (default: 3)
|
| 137 |
+
|
| 138 |
+
### Event Gate Parameters
|
| 139 |
+
|
| 140 |
+
These control when detections are considered valid (filtering noise):
|
| 141 |
+
|
| 142 |
+
- `--level-delta-on-db FLOAT`: Level increase threshold to open gate (default: 2.5)
|
| 143 |
+
- `--level-delta-off-db FLOAT`: Level decrease threshold to close gate (default: 1.0)
|
| 144 |
+
- `--level-min-dbfs FLOAT`: Minimum audio level in dBFS (default: -60.0)
|
| 145 |
+
- `--level-ema-alpha FLOAT`: Exponential moving average alpha for level tracking (default: 0.05)
|
| 146 |
+
- `--event-hold-ms INT`: Minimum time to keep gate open after detection (default: 300)
|
| 147 |
+
- `--min-R-clip FLOAT`: Minimum R_clip (coherence measure) to open gate (default: 0.18)
|
| 148 |
+
- `--event-refractory-ms INT`: Minimum time between gate state changes (default: 120)
|
| 149 |
+
|
| 150 |
+
### Onset Detection Parameters
|
| 151 |
+
|
| 152 |
+
- `--onset-alpha FLOAT`: EMA alpha for spectral flux tracking (default: 0.05)
|
| 153 |
+
|
| 154 |
+
## Example with Custom Parameters
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
python onnx_stream_microphone.py \
|
| 158 |
+
--onnx doa_model.onnx \
|
| 159 |
+
--device-index 9 \
|
| 160 |
+
--window-ms 400 \
|
| 161 |
+
--hop-ms 100 \
|
| 162 |
+
--K 72 \
|
| 163 |
+
--max-sources 2 \
|
| 164 |
+
--tau 0.8 \
|
| 165 |
+
--smooth-k 1 \
|
| 166 |
+
--min-peak-height 0.08 \
|
| 167 |
+
--min-window-mass 0.16 \
|
| 168 |
+
--min-sep-deg 22.5 \
|
| 169 |
+
--min-active-ratio 0.15 \
|
| 170 |
+
--level-delta-on-db 4.0 \
|
| 171 |
+
--level-delta-off-db 1.5 \
|
| 172 |
+
--level-min-dbfs -55.0 \
|
| 173 |
+
--level-ema-alpha 0.05 \
|
| 174 |
+
--event-hold-ms 320 \
|
| 175 |
+
--event-refractory-ms 200 \
|
| 176 |
+
--min-R-clip 0.30 \
|
| 177 |
+
--onset-alpha 0.05
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
## Understanding the Output
|
| 181 |
+
|
| 182 |
+
### Console Output
|
| 183 |
+
|
| 184 |
+
Each line shows:
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
[ 12.34s] LVL= -45.2 dBFS diff=+3.5 | FLUXz=2.10 COH=0.75 | GATE=OPEN | MODEL= 12.3ms HIST= 2.1ms | DOA(R=0.45, n=2) [45°, 180°]
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
- `[time]`: Elapsed time in seconds
|
| 191 |
+
- `LVL`: Audio level in dBFS
|
| 192 |
+
- `diff`: Level difference from background (dB)
|
| 193 |
+
- `FLUXz`: Spectral flux z-score (onset detection)
|
| 194 |
+
- `COH`: Inter-microphone coherence
|
| 195 |
+
- `GATE`: Gate state (OPEN/CLOSED)
|
| 196 |
+
- `MODEL`: Model inference time (ms)
|
| 197 |
+
- `HIST`: Histogram processing time (ms)
|
| 198 |
+
- `DOA(R=..., n=...)`: R_clip value and number of detected peaks
|
| 199 |
+
- `[angles]`: Detected azimuth angles in degrees
|
| 200 |
+
|
| 201 |
+
### Visual Output
|
| 202 |
+
|
| 203 |
+
A polar plot window shows:
|
| 204 |
+
- **Green lines**: Detected sound source directions
|
| 205 |
+
- **Line thickness**: Proportional to confidence score
|
| 206 |
+
- **Angle labels**: Azimuth in degrees (0° = North/front)
|
| 207 |
+
|
| 208 |
+
### Azimuth Convention
|
| 209 |
+
|
| 210 |
+
- **0°** = North (front of microphone)
|
| 211 |
+
- **90°** = East (right)
|
| 212 |
+
- **180°** = South (back)
|
| 213 |
+
- **270°** = West (left)
|
| 214 |
+
|
| 215 |
+
## How It Works
|
| 216 |
+
|
| 217 |
+
### 1. Audio Processing Pipeline
|
| 218 |
+
|
| 219 |
+
```
|
| 220 |
+
Microphone (6 ch) → Channel Selection [1,4,3,2] → 4-channel audio
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
### 2. Feature Extraction
|
| 224 |
+
|
| 225 |
+
For each analysis window:
|
| 226 |
+
- Compute STFT for all 4 channels
|
| 227 |
+
- Extract magnitude, phase, cosine, and sine components
|
| 228 |
+
- Result: `(T_frames, 12_features, F_freq_bins)`
|
| 229 |
+
|
| 230 |
+
### 3. Model Inference
|
| 231 |
+
|
| 232 |
+
- Batch process features through ONNX model
|
| 233 |
+
- Output: `(T_frames, K_bins)` logits per frame
|
| 234 |
+
- Each frame has K probability scores for different azimuth angles
|
| 235 |
+
|
| 236 |
+
### 4. Histogram Aggregation
|
| 237 |
+
|
| 238 |
+
- Apply softmax with temperature `tau` to logits
|
| 239 |
+
- Weight by circular coherence (R_clip)
|
| 240 |
+
- Aggregate across all frames into a single histogram
|
| 241 |
+
- Smooth the histogram
|
| 242 |
+
|
| 243 |
+
### 5. Peak Detection
|
| 244 |
+
|
| 245 |
+
- Find local maxima in the histogram
|
| 246 |
+
- Filter by minimum height, separation, and window mass
|
| 247 |
+
- Refine peak positions using parabolic interpolation
|
| 248 |
+
- Return up to `max_sources` peaks
|
| 249 |
+
|
| 250 |
+
### 6. Event Gating
|
| 251 |
+
|
| 252 |
+
- Track audio level with exponential moving average
|
| 253 |
+
- Open gate when:
|
| 254 |
+
- Level increases by `level_delta_on_db` OR
|
| 255 |
+
- Valid peaks detected AND R_clip > `min_R_clip`
|
| 256 |
+
- Close gate when level drops and no valid peaks
|
| 257 |
+
- Apply hold and refractory periods to prevent flickering
|
| 258 |
+
|
| 259 |
+
## Troubleshooting
|
| 260 |
+
|
| 261 |
+
### "Invalid number of channels" Error
|
| 262 |
+
|
| 263 |
+
**Problem**: Device reports 0 channels or PyAudio can't open it.
|
| 264 |
+
|
| 265 |
+
**Solution**:
|
| 266 |
+
1. Stop PulseAudio: `pulseaudio --kill`
|
| 267 |
+
2. Run the script
|
| 268 |
+
3. Restart PulseAudio: `pulseaudio --start`
|
| 269 |
+
|
| 270 |
+
Or use the helper script `run_onnx_stream.sh`.
|
| 271 |
+
|
| 272 |
+
### No Audio Detected
|
| 273 |
+
|
| 274 |
+
- Check microphone connections
|
| 275 |
+
- Verify device index with `--list-devices`
|
| 276 |
+
- Check audio levels (should be above `level_min_dbfs`)
|
| 277 |
+
- Adjust `level_delta_on_db` to be more sensitive
|
| 278 |
+
|
| 279 |
+
### GPU Not Used
|
| 280 |
+
|
| 281 |
+
- Verify CUDA is available: `python -c "import torch; print(torch.cuda.is_available())"`
|
| 282 |
+
- Install `onnxruntime-gpu` instead of `onnxruntime`
|
| 283 |
+
- Check that CUDA providers are listed in the model loading message
|
| 284 |
+
|
| 285 |
+
### Model Mismatch Errors
|
| 286 |
+
|
| 287 |
+
- Ensure `--K` matches the model's K value (usually 72)
|
| 288 |
+
- Check that the ONNX model was exported with the correct input shape
|
| 289 |
+
- Verify config.yaml matches training configuration
|
| 290 |
+
|
| 291 |
+
### Poor DOA Accuracy
|
| 292 |
+
|
| 293 |
+
- Increase `--window-ms` for longer analysis windows (more stable)
|
| 294 |
+
- Adjust `--min-peak-height` and `--min-window-mass` thresholds
|
| 295 |
+
- Tune `--tau` (lower = sharper peaks, higher = smoother)
|
| 296 |
+
- Check microphone array calibration and positioning
|
| 297 |
+
|
| 298 |
+
## Performance Tips
|
| 299 |
+
|
| 300 |
+
- **GPU Inference**: Use `onnxruntime-gpu` for 5-10x speedup
|
| 301 |
+
- **Window Size**: Larger windows (400ms) = more stable but higher latency
|
| 302 |
+
- **Hop Size**: Smaller hops (50ms) = more responsive but more computation
|
| 303 |
+
- **Batch Size**: The script uses batch_size=25 internally for efficient GPU usage
|
| 304 |
+
|
| 305 |
+
## Stopping the Script
|
| 306 |
+
|
| 307 |
+
Press **Ctrl+C** to stop the stream. The script will:
|
| 308 |
+
- Close the audio stream
|
| 309 |
+
- Close the visualization window
|
| 310 |
+
- Clean up resources
|
| 311 |
+
|
| 312 |
+
## Integration
|
| 313 |
+
|
| 314 |
+
To use this in your own code, see `onnx_doa_inference.py` which provides a standalone inference class that can be integrated into other projects.
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
|
doa_model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:148f874c4fdc302a4d1808d6d0a45e1b3b40aeb6f8c6b2ef7e423710b2b28cba
|
| 3 |
+
size 785447
|
features.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from numpy.fft import rfft
|
| 3 |
+
from numpy.lib.stride_tricks import as_strided
|
| 4 |
+
from scipy.signal import get_window
|
| 5 |
+
|
| 6 |
+
def stft_multi(
|
| 7 |
+
x,
|
| 8 |
+
fs: float,
|
| 9 |
+
win_s: float = 0.032,
|
| 10 |
+
hop_s: float = 0.010,
|
| 11 |
+
nfft: int | None = None,
|
| 12 |
+
window: str | tuple | np.ndarray = "hann",
|
| 13 |
+
center: bool = True,
|
| 14 |
+
pad_mode: str = "reflect",
|
| 15 |
+
out_dtype = np.complex64,
|
| 16 |
+
):
|
| 17 |
+
"""
|
| 18 |
+
Multichannel STFT (vectorized).
|
| 19 |
+
Args
|
| 20 |
+
----
|
| 21 |
+
x : np.ndarray, shape (N, C) time-domain signal
|
| 22 |
+
fs : float, sampling rate (Hz)
|
| 23 |
+
win_s : float, window length in seconds (default 32 ms)
|
| 24 |
+
hop_s : float, hop length in seconds (default 10 ms)
|
| 25 |
+
nfft : int or None. If None, uses next power of two >= frame_len
|
| 26 |
+
window : str/tuple/array for scipy.signal.get_window or a length-L array
|
| 27 |
+
center : if True, pad by L//2 on both sides (librosa-style)
|
| 28 |
+
pad_mode: np.pad mode (e.g., "reflect", "constant")
|
| 29 |
+
out_dtype: dtype for STFT output (complex64 recommended)
|
| 30 |
+
|
| 31 |
+
Returns
|
| 32 |
+
-------
|
| 33 |
+
X : np.ndarray, shape (T, C, F) complex STFT
|
| 34 |
+
freqs: np.ndarray, shape (F,) frequency bins in Hz
|
| 35 |
+
times: np.ndarray, shape (T,) frame center times in seconds
|
| 36 |
+
"""
|
| 37 |
+
x = np.asarray(x)
|
| 38 |
+
if x.ndim == 1:
|
| 39 |
+
x = x[:, None] # (N,1)
|
| 40 |
+
assert x.ndim == 2, "x must be (samples, channels)"
|
| 41 |
+
N, C = x.shape
|
| 42 |
+
|
| 43 |
+
# Window & hop in samples
|
| 44 |
+
frame_len = int(round(win_s * fs))
|
| 45 |
+
hop = int(round(hop_s * fs))
|
| 46 |
+
if frame_len <= 0 or hop <= 0:
|
| 47 |
+
raise ValueError("win_s and hop_s must be > 0")
|
| 48 |
+
|
| 49 |
+
# FFT size
|
| 50 |
+
def _next_pow2(n):
|
| 51 |
+
return 1 << (int(n - 1).bit_length())
|
| 52 |
+
nfft = _next_pow2(frame_len) if nfft is None else int(nfft)
|
| 53 |
+
if nfft < frame_len:
|
| 54 |
+
raise ValueError("nfft must be >= frame_len")
|
| 55 |
+
|
| 56 |
+
# Window vector
|
| 57 |
+
if isinstance(window, np.ndarray):
|
| 58 |
+
w = window.astype(float, copy=False)
|
| 59 |
+
else:
|
| 60 |
+
w = get_window(window, frame_len, fftbins=True).astype(float)
|
| 61 |
+
if w.shape[0] != frame_len:
|
| 62 |
+
raise ValueError("Provided window length != frame_len")
|
| 63 |
+
|
| 64 |
+
# Optional centering (pad by L//2 on both sides)
|
| 65 |
+
pad = frame_len // 2 if center else 0
|
| 66 |
+
if pad > 0:
|
| 67 |
+
x_pad = np.pad(x, ((pad, pad), (0, 0)), mode=pad_mode)
|
| 68 |
+
else:
|
| 69 |
+
x_pad = x
|
| 70 |
+
|
| 71 |
+
Np = x_pad.shape[0]
|
| 72 |
+
if Np < frame_len:
|
| 73 |
+
# ensure at least one frame
|
| 74 |
+
x_pad = np.pad(x_pad, ((0, frame_len - Np), (0, 0)), mode=pad_mode)
|
| 75 |
+
Np = x_pad.shape[0]
|
| 76 |
+
|
| 77 |
+
# Number of frames
|
| 78 |
+
T = 1 + (Np - frame_len) // hop
|
| 79 |
+
if T <= 0:
|
| 80 |
+
raise ValueError("Signal too short for given window/hop")
|
| 81 |
+
|
| 82 |
+
# Stride-trick framing: (T, frame_len, C) view into x_pad
|
| 83 |
+
s_t, s_c = x_pad.strides # bytes per step in time/channel
|
| 84 |
+
frames = as_strided(
|
| 85 |
+
x_pad,
|
| 86 |
+
shape=(T, frame_len, C),
|
| 87 |
+
strides=(hop * s_t, s_t, s_c),
|
| 88 |
+
writeable=False,
|
| 89 |
+
)
|
| 90 |
+
# Reorder to (T, C, frame_len) to apply window & FFT along the last axis
|
| 91 |
+
frames = np.transpose(frames, (0, 2, 1)) # (T, C, L)
|
| 92 |
+
|
| 93 |
+
# Apply window (broadcast over T and C)
|
| 94 |
+
frames = frames * w[None, None, :]
|
| 95 |
+
|
| 96 |
+
# Batched real FFT along last axis -> (T, C, F)
|
| 97 |
+
X = rfft(frames, n=nfft, axis=-1).astype(out_dtype, copy=False)
|
| 98 |
+
|
| 99 |
+
# Frequency and time vectors
|
| 100 |
+
F = X.shape[-1]
|
| 101 |
+
freqs = (fs / nfft) * np.arange(F)
|
| 102 |
+
# Frame centers relative to original signal
|
| 103 |
+
if center:
|
| 104 |
+
# centers at sample indices: t*hop (librosa convention)
|
| 105 |
+
times = (np.arange(T) * hop) / fs
|
| 106 |
+
else:
|
| 107 |
+
# window centered at (frame_len/2) + t*hop
|
| 108 |
+
times = (np.arange(T) * hop + frame_len / 2.0) / fs
|
| 109 |
+
|
| 110 |
+
return X, freqs, times
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _wrap_to_2pi(x: np.ndarray) -> np.ndarray:
|
| 115 |
+
"""Wrap angles to [0, 2π)."""
|
| 116 |
+
return np.mod(x, 2.0 * np.pi)
|
| 117 |
+
|
| 118 |
+
def compute_mag_phase(
|
| 119 |
+
X: np.ndarray,
|
| 120 |
+
dtype=np.float32,
|
| 121 |
+
):
|
| 122 |
+
"""
|
| 123 |
+
Per-channel magnitude and absolute phase (wrapped to [0, 2π)).
|
| 124 |
+
|
| 125 |
+
Args
|
| 126 |
+
----
|
| 127 |
+
X : np.ndarray, shape (T, C, F), complex STFT
|
| 128 |
+
dtype: output dtype
|
| 129 |
+
|
| 130 |
+
Returns
|
| 131 |
+
-------
|
| 132 |
+
mag : np.ndarray, shape (T, C, F) = |X|
|
| 133 |
+
phase : np.ndarray, shape (T, C, F) = angle(X) in [0, 2π)
|
| 134 |
+
"""
|
| 135 |
+
assert X.ndim == 3, "X must be (T, C, F)"
|
| 136 |
+
mag = np.abs(X).astype(dtype, copy=False)
|
| 137 |
+
phase = _wrap_to_2pi(np.angle(X)).astype(dtype, copy=False)
|
| 138 |
+
return mag, phase
|
| 139 |
+
|
| 140 |
+
def compute_mag_phase_cos_sin(
|
| 141 |
+
X: np.ndarray,
|
| 142 |
+
dtype=np.float32,
|
| 143 |
+
):
|
| 144 |
+
"""
|
| 145 |
+
Concatenate per-channel magnitude, cos(phase), sin(phase).
|
| 146 |
+
|
| 147 |
+
Args
|
| 148 |
+
----
|
| 149 |
+
X : np.ndarray, shape (T, C, F), complex STFT
|
| 150 |
+
dtype: output dtype
|
| 151 |
+
|
| 152 |
+
Returns
|
| 153 |
+
-------
|
| 154 |
+
feats : np.ndarray, shape (T, 3*C, F)
|
| 155 |
+
Layout = [mag (C), cos(phase) (C), sin(phase) (C)]
|
| 156 |
+
where phase is angle(X) wrapped to [0, 2π).
|
| 157 |
+
"""
|
| 158 |
+
mag, phase = compute_mag_phase(X, dtype=dtype)
|
| 159 |
+
cos_phase = np.cos(phase).astype(dtype, copy=False)
|
| 160 |
+
sin_phase = np.sin(phase).astype(dtype, copy=False)
|
| 161 |
+
feats = np.concatenate([mag, cos_phase, sin_phase], axis=1)
|
| 162 |
+
return feats
|
| 163 |
+
|
| 164 |
+
def compute_real_imag_features(
|
| 165 |
+
X: np.ndarray,
|
| 166 |
+
dtype=np.float32,
|
| 167 |
+
):
|
| 168 |
+
"""
|
| 169 |
+
Concatenate per-channel real and imaginary parts.
|
| 170 |
+
|
| 171 |
+
Args
|
| 172 |
+
----
|
| 173 |
+
X : np.ndarray, shape (T, C, F), complex STFT
|
| 174 |
+
dtype: output dtype
|
| 175 |
+
|
| 176 |
+
Returns
|
| 177 |
+
-------
|
| 178 |
+
feats : np.ndarray, shape (T, 2*C, F)
|
| 179 |
+
Layout = [Re (C), Im (C)]
|
| 180 |
+
"""
|
| 181 |
+
assert X.ndim == 3, "X must be (T, C, F)"
|
| 182 |
+
real = X.real.astype(dtype, copy=False)
|
| 183 |
+
imag = X.imag.astype(dtype, copy=False)
|
| 184 |
+
feats = np.concatenate([real, imag], axis=1)
|
| 185 |
+
return feats
|
| 186 |
+
|
onnx_stream_microphone.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Real-time DOA inference using ONNX model with microphone streaming.
|
| 4 |
+
Includes histogram-based detection, event gates, and onset detection.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add src directory to Python path
|
| 11 |
+
project_root = Path(__file__).parent
|
| 12 |
+
src_dir = project_root / "src"
|
| 13 |
+
if str(src_dir) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(src_dir))
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import numpy as np
|
| 18 |
+
import time
|
| 19 |
+
import queue
|
| 20 |
+
import argparse
|
| 21 |
+
import pyaudio
|
| 22 |
+
import onnxruntime as ort
|
| 23 |
+
import yaml
|
| 24 |
+
from typing import Optional, Dict, List, Tuple
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
from matplotlib.patches import Circle
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
from mirokai_doa.features import stft_multi, compute_mag_phase_cos_sin
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# -------------------------
|
| 36 |
+
# Math helpers (numpy version)
|
| 37 |
+
# -------------------------
|
| 38 |
+
def _angles_deg_np(K: int):
|
| 39 |
+
bin_size = 360.0 / K
|
| 40 |
+
deg = (np.arange(K, dtype=np.float32) + 0.5) * bin_size
|
| 41 |
+
rad = deg * np.pi / 180.0
|
| 42 |
+
return deg, np.cos(rad), np.sin(rad), bin_size
|
| 43 |
+
|
| 44 |
+
def _softmax_temp_np(logits: np.ndarray, tau: float = 0.8) -> np.ndarray:
|
| 45 |
+
exp_logits = np.exp((logits - np.max(logits, axis=-1, keepdims=True)) / tau)
|
| 46 |
+
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
|
| 47 |
+
|
| 48 |
+
def _circular_window_sum_np(row: np.ndarray, idx: int, half_w: int) -> float:
|
| 49 |
+
K = row.size
|
| 50 |
+
if half_w <= 0:
|
| 51 |
+
return float(row[idx])
|
| 52 |
+
acc = 0.0
|
| 53 |
+
for d in range(-half_w, half_w + 1):
|
| 54 |
+
acc += float(row[(idx + d) % K])
|
| 55 |
+
return acc
|
| 56 |
+
|
| 57 |
+
def _parabolic_peak_refine_np(row: np.ndarray, k: int) -> float:
|
| 58 |
+
K = row.size
|
| 59 |
+
km1, kp1 = (k - 1) % K, (k + 1) % K
|
| 60 |
+
y1, y2, y3 = float(row[km1]), float(row[k]), float(row[kp1])
|
| 61 |
+
denom = (y1 - 2 * y2 + y3)
|
| 62 |
+
if abs(denom) < 1e-9:
|
| 63 |
+
return 0.0
|
| 64 |
+
delta = 0.5 * (y1 - y3) / denom
|
| 65 |
+
return float(max(min(delta, 0.5), -0.5))
|
| 66 |
+
|
| 67 |
+
def _min_circ_separation_bins(a: int, chosen: List[int], K: int) -> int:
|
| 68 |
+
if not chosen:
|
| 69 |
+
return K
|
| 70 |
+
dmin = K
|
| 71 |
+
for j in chosen:
|
| 72 |
+
d = abs(a - j)
|
| 73 |
+
d = min(d, K - d)
|
| 74 |
+
dmin = min(dmin, d)
|
| 75 |
+
return dmin
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# -------------------------
|
| 79 |
+
# Audio helpers
|
| 80 |
+
# -------------------------
|
| 81 |
+
def byte_to_float(data: bytes) -> np.ndarray:
|
| 82 |
+
samples = np.frombuffer(data, dtype=np.int16)
|
| 83 |
+
return samples.astype(np.float32) / 32768.0
|
| 84 |
+
|
| 85 |
+
def chunk_to_floatarray(data: bytes, channels: int) -> np.ndarray:
|
| 86 |
+
float_data = byte_to_float(data)
|
| 87 |
+
return float_data.reshape(-1, channels).T
|
| 88 |
+
|
| 89 |
+
def rms_dbfs(x: np.ndarray, eps: float = 1e-9) -> float:
|
| 90 |
+
val = np.sqrt((x * x).mean())
|
| 91 |
+
return 20.0 * np.log10(max(val, eps))
|
| 92 |
+
|
| 93 |
+
def frame_rms_energy(audio_buffer: np.ndarray, T: int) -> np.ndarray:
|
| 94 |
+
"""Split audio_buffer (C,N) into T equal segments; return per-frame RMS (normalized)."""
|
| 95 |
+
C, N = audio_buffer.shape
|
| 96 |
+
if T <= 0:
|
| 97 |
+
return np.ones(1, dtype=np.float32)
|
| 98 |
+
edges = np.linspace(0, N, T + 1, dtype=int)
|
| 99 |
+
e = []
|
| 100 |
+
for i in range(T):
|
| 101 |
+
seg = audio_buffer[:, edges[i]:edges[i+1]]
|
| 102 |
+
if seg.size == 0:
|
| 103 |
+
e.append(0.0)
|
| 104 |
+
else:
|
| 105 |
+
rms = np.sqrt((seg * seg).mean())
|
| 106 |
+
e.append(rms)
|
| 107 |
+
e = np.asarray(e, dtype=np.float32)
|
| 108 |
+
e = e / max(e.mean(), 1e-6)
|
| 109 |
+
return e
|
| 110 |
+
|
| 111 |
+
def spectral_flux_per_frame(audio_buffer: np.ndarray, T: int) -> np.ndarray:
|
| 112 |
+
"""Compute per-frame spectral flux across T segments from mono mix."""
|
| 113 |
+
C, N = audio_buffer.shape
|
| 114 |
+
if T <= 1:
|
| 115 |
+
return np.zeros((T,), dtype=np.float32)
|
| 116 |
+
mono = audio_buffer.mean(axis=0)
|
| 117 |
+
edges = np.linspace(0, N, T + 1, dtype=int)
|
| 118 |
+
mags = []
|
| 119 |
+
for i in range(T):
|
| 120 |
+
seg = mono[edges[i]:edges[i+1]]
|
| 121 |
+
if seg.size == 0:
|
| 122 |
+
mags.append(np.zeros(1, dtype=np.float32))
|
| 123 |
+
continue
|
| 124 |
+
win = np.hanning(len(seg)) if len(seg) > 8 else np.ones_like(seg)
|
| 125 |
+
S = np.fft.rfft(seg * win, n=len(seg))
|
| 126 |
+
mags.append(np.abs(S).astype(np.float32))
|
| 127 |
+
flux = np.zeros(T, dtype=np.float32)
|
| 128 |
+
for t in range(1, T):
|
| 129 |
+
a = mags[t-1]
|
| 130 |
+
b = mags[t]
|
| 131 |
+
L = min(len(a), len(b))
|
| 132 |
+
if L == 0:
|
| 133 |
+
flux[t] = 0.0
|
| 134 |
+
continue
|
| 135 |
+
diff = b[:L] - a[:L]
|
| 136 |
+
pos = np.maximum(diff, 0.0)
|
| 137 |
+
denom = np.sum(b[:L]) + 1e-6
|
| 138 |
+
flux[t] = float(np.sum(pos) / denom)
|
| 139 |
+
return flux
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# -------------------------
|
| 143 |
+
# Onset detector
|
| 144 |
+
# -------------------------
|
| 145 |
+
class OnsetDetector:
|
| 146 |
+
def __init__(self, alpha: float = 0.05):
|
| 147 |
+
self.alpha = float(alpha)
|
| 148 |
+
self.mu = 0.0
|
| 149 |
+
self.var = 1.0
|
| 150 |
+
self.inited = False
|
| 151 |
+
|
| 152 |
+
def update_flux(self, flux_recent: float) -> float:
|
| 153 |
+
if not self.inited:
|
| 154 |
+
self.mu = flux_recent
|
| 155 |
+
self.var = 1e-3 + abs(flux_recent)
|
| 156 |
+
self.inited = True
|
| 157 |
+
delta = flux_recent - self.mu
|
| 158 |
+
self.mu += self.alpha * delta
|
| 159 |
+
self.var = (1 - self.alpha) * self.var + self.alpha * delta * delta
|
| 160 |
+
sigma = max(np.sqrt(self.var), 1e-6)
|
| 161 |
+
z = (flux_recent - self.mu) / sigma
|
| 162 |
+
return float(z)
|
| 163 |
+
|
| 164 |
+
@staticmethod
|
| 165 |
+
def last_segment_coherence(audio_buffer: np.ndarray, T: int,
|
| 166 |
+
pairs: List[Tuple[int,int]] = [(0,1),(0,2),(0,3)]) -> float:
|
| 167 |
+
C, N = audio_buffer.shape
|
| 168 |
+
if T < 1:
|
| 169 |
+
return 0.0
|
| 170 |
+
edges = np.linspace(0, N, T + 1, dtype=int)
|
| 171 |
+
s0, s1 = int(edges[-2]), int(edges[-1])
|
| 172 |
+
seg = audio_buffer[:, s0:s1]
|
| 173 |
+
if seg.shape[1] < 16:
|
| 174 |
+
return 0.0
|
| 175 |
+
rmax = 0.0
|
| 176 |
+
for (i,j) in pairs:
|
| 177 |
+
xi = seg[i] - seg[i].mean()
|
| 178 |
+
xj = seg[j] - seg[j].mean()
|
| 179 |
+
denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-9)
|
| 180 |
+
r = float(np.dot(xi, xj) / denom)
|
| 181 |
+
rmax = max(rmax, abs(r))
|
| 182 |
+
return rmax
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# -------------------------
|
| 186 |
+
# Histogram DOA detector (numpy/torch hybrid)
|
| 187 |
+
# -------------------------
|
| 188 |
+
class HistDOADetector:
|
| 189 |
+
def __init__(
|
| 190 |
+
self,
|
| 191 |
+
K: int = 72,
|
| 192 |
+
tau: float = 0.8,
|
| 193 |
+
gamma: float = 1.5,
|
| 194 |
+
smooth_k: int = 1,
|
| 195 |
+
window_bins: int = 1,
|
| 196 |
+
min_peak_height: float = 0.10,
|
| 197 |
+
min_window_mass: float = 0.24,
|
| 198 |
+
min_sep_deg: float = 20.0,
|
| 199 |
+
min_active_ratio: float = 0.20,
|
| 200 |
+
max_sources: int = 3,
|
| 201 |
+
device: str = "cpu",
|
| 202 |
+
):
|
| 203 |
+
self.K = int(K)
|
| 204 |
+
self.tau = float(tau)
|
| 205 |
+
self.gamma = float(gamma)
|
| 206 |
+
self.smooth_k = int(smooth_k)
|
| 207 |
+
self.window_bins = int(window_bins)
|
| 208 |
+
self.min_peak_height = float(min_peak_height)
|
| 209 |
+
self.min_window_mass = float(min_window_mass)
|
| 210 |
+
self.min_sep_deg = float(min_sep_deg)
|
| 211 |
+
self.min_active_ratio = float(min_active_ratio)
|
| 212 |
+
self.max_sources = int(max_sources)
|
| 213 |
+
self.device = torch.device(device)
|
| 214 |
+
self._deg, self._cos, self._sin, self._bin_size = self._angles_deg(self.K)
|
| 215 |
+
|
| 216 |
+
def _angles_deg(self, K: int):
|
| 217 |
+
bin_size = 360.0 / K
|
| 218 |
+
deg = torch.arange(K, device=self.device, dtype=torch.float32) + 0.5
|
| 219 |
+
deg = deg * bin_size
|
| 220 |
+
rad = deg * math.pi / 180.0
|
| 221 |
+
return deg, torch.cos(rad), torch.sin(rad), bin_size
|
| 222 |
+
|
| 223 |
+
def _aggregate_histogram(self, logits: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, float, float]:
|
| 224 |
+
"""Aggregate histogram from logits and VAD mask."""
|
| 225 |
+
logits_t = torch.from_numpy(logits).float().to(self.device)
|
| 226 |
+
mask_t = torch.from_numpy(mask).float().to(self.device)
|
| 227 |
+
|
| 228 |
+
probs = F.softmax(logits_t / self.tau, dim=-1) # [T,K]
|
| 229 |
+
T = probs.shape[0]
|
| 230 |
+
m = mask_t
|
| 231 |
+
|
| 232 |
+
# Weighted histogram
|
| 233 |
+
x = torch.matmul(probs, self._cos)
|
| 234 |
+
y = torch.matmul(probs, self._sin)
|
| 235 |
+
R_t = torch.clamp(torch.sqrt(x * x + y * y), 0, 1)
|
| 236 |
+
w = m * (R_t ** self.gamma)
|
| 237 |
+
|
| 238 |
+
if w.sum() <= 0:
|
| 239 |
+
w = torch.ones_like(w) * 1e-6
|
| 240 |
+
|
| 241 |
+
hist = torch.matmul(w, probs)
|
| 242 |
+
hist = hist / hist.sum().clamp_min(1e-8)
|
| 243 |
+
|
| 244 |
+
if self.smooth_k > 0:
|
| 245 |
+
s = self.smooth_k
|
| 246 |
+
pad = torch.cat([hist[-s:], hist, hist[:s]], dim=0).view(1, 1, -1)
|
| 247 |
+
kernel = torch.ones(1, 1, 2 * s + 1, device=self.device) / (2 * s + 1)
|
| 248 |
+
hist = F.conv1d(pad, kernel, padding=0).view(-1)
|
| 249 |
+
|
| 250 |
+
X = torch.dot(hist, self._cos)
|
| 251 |
+
Y = torch.dot(hist, self._sin)
|
| 252 |
+
R_clip = float(torch.sqrt(X * X + Y * Y).item())
|
| 253 |
+
active_ratio = float(m.mean().item())
|
| 254 |
+
return hist.detach().cpu().numpy(), active_ratio, R_clip
|
| 255 |
+
|
| 256 |
+
def _pick_peaks(self, hist: np.ndarray) -> List[Dict[str, float]]:
|
| 257 |
+
"""Pick peaks from histogram."""
|
| 258 |
+
hist_t = torch.from_numpy(hist).float()
|
| 259 |
+
K = self.K
|
| 260 |
+
bin_size = self._bin_size
|
| 261 |
+
|
| 262 |
+
left = torch.roll(hist_t, 1, 0)
|
| 263 |
+
right = torch.roll(hist_t, -1, 0)
|
| 264 |
+
cand_idxs = ((hist_t > left) & (hist_t > right)).nonzero(as_tuple=False).flatten().tolist()
|
| 265 |
+
cand_idxs.sort(key=lambda i: float(hist_t[i].item()), reverse=True)
|
| 266 |
+
|
| 267 |
+
chosen, out = [], []
|
| 268 |
+
min_sep_bins = max(1, int(round(self.min_sep_deg / bin_size)))
|
| 269 |
+
|
| 270 |
+
for idx in cand_idxs:
|
| 271 |
+
if _min_circ_separation_bins(idx, chosen, K) < min_sep_bins:
|
| 272 |
+
continue
|
| 273 |
+
if float(hist_t[idx].item()) < self.min_peak_height:
|
| 274 |
+
continue
|
| 275 |
+
mass = _circular_window_sum_np(hist, idx, self.window_bins)
|
| 276 |
+
if mass < self.min_window_mass:
|
| 277 |
+
continue
|
| 278 |
+
delta = _parabolic_peak_refine_np(hist, idx)
|
| 279 |
+
angle_deg = ((idx + 0.5 + delta) * bin_size) % 360.0
|
| 280 |
+
out.append({"azimuth_deg": angle_deg, "score": float(mass)})
|
| 281 |
+
chosen.append(idx)
|
| 282 |
+
if len(out) >= self.max_sources:
|
| 283 |
+
break
|
| 284 |
+
return out
|
| 285 |
+
|
| 286 |
+
def detect(self, logits: np.ndarray) -> Dict[str, any]:
|
| 287 |
+
"""Detect DOA from logits (no VAD separation)."""
|
| 288 |
+
# Use all frames (no VAD masking)
|
| 289 |
+
mask = np.ones(logits.shape[0], dtype=np.float32)
|
| 290 |
+
|
| 291 |
+
hist, active_ratio, R_clip = self._aggregate_histogram(logits, mask)
|
| 292 |
+
|
| 293 |
+
peaks = self._pick_peaks(hist) if active_ratio >= self.min_active_ratio else []
|
| 294 |
+
|
| 295 |
+
bins_deg = (np.arange(self.K) + 0.5) * (360.0 / self.K)
|
| 296 |
+
return {
|
| 297 |
+
"peaks": peaks,
|
| 298 |
+
"active_ratio": active_ratio,
|
| 299 |
+
"R_clip": R_clip,
|
| 300 |
+
"hist": hist,
|
| 301 |
+
"bins_deg": bins_deg,
|
| 302 |
+
"has_event": bool(peaks),
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# -------------------------
|
| 307 |
+
# Event gate
|
| 308 |
+
# -------------------------
|
| 309 |
+
class LevelChangeGate:
|
| 310 |
+
def __init__(
|
| 311 |
+
self,
|
| 312 |
+
delta_on_db: float = 2.5,
|
| 313 |
+
delta_off_db: float = 1.0,
|
| 314 |
+
level_min_dbfs: float = -60.0,
|
| 315 |
+
ema_alpha: float = 0.05,
|
| 316 |
+
min_R_clip: float = 0.18,
|
| 317 |
+
hold_ms: int = 300,
|
| 318 |
+
refractory_ms: int = 120
|
| 319 |
+
):
|
| 320 |
+
self.delta_on_db = float(delta_on_db)
|
| 321 |
+
self.delta_off_db = float(delta_off_db)
|
| 322 |
+
self.level_min_dbfs = float(level_min_dbfs)
|
| 323 |
+
self.ema_alpha = float(ema_alpha)
|
| 324 |
+
self.min_R_clip = float(min_R_clip)
|
| 325 |
+
self.hold_s = float(hold_ms) / 1000.0
|
| 326 |
+
self.refractory_s = float(refractory_ms) / 1000.0
|
| 327 |
+
self.bg_dbfs = None
|
| 328 |
+
self.active = False
|
| 329 |
+
self.last_change_time = 0.0
|
| 330 |
+
|
| 331 |
+
def update(self, level_dbfs: float, now_s: float,
|
| 332 |
+
peaks_count: int, R_clip_max: float):
|
| 333 |
+
if self.bg_dbfs is None:
|
| 334 |
+
self.bg_dbfs = level_dbfs
|
| 335 |
+
diff_db = level_dbfs - self.bg_dbfs
|
| 336 |
+
|
| 337 |
+
want_open = (
|
| 338 |
+
(now_s - self.last_change_time) >= self.refractory_s and
|
| 339 |
+
((level_dbfs > self.level_min_dbfs and diff_db >= self.delta_on_db) or
|
| 340 |
+
(peaks_count > 0 and R_clip_max >= self.min_R_clip))
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if not self.active:
|
| 344 |
+
if want_open:
|
| 345 |
+
self.active = True
|
| 346 |
+
self.last_change_time = now_s
|
| 347 |
+
else:
|
| 348 |
+
if (now_s - self.last_change_time) >= self.hold_s:
|
| 349 |
+
want_close = (
|
| 350 |
+
(diff_db <= self.delta_off_db) and
|
| 351 |
+
(peaks_count == 0 or R_clip_max < self.min_R_clip)
|
| 352 |
+
)
|
| 353 |
+
if want_close:
|
| 354 |
+
self.active = False
|
| 355 |
+
self.last_change_time = now_s
|
| 356 |
+
|
| 357 |
+
self.bg_dbfs = (1.0 - self.ema_alpha) * self.bg_dbfs + self.ema_alpha * level_dbfs
|
| 358 |
+
return self.active, diff_db
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# -------------------------
|
| 362 |
+
# ONNX Inference
|
| 363 |
+
# -------------------------
|
| 364 |
+
class ONNXDOAStreaming:
|
| 365 |
+
def __init__(
|
| 366 |
+
self,
|
| 367 |
+
onnx_path: str,
|
| 368 |
+
config_path: Optional[str] = None,
|
| 369 |
+
providers: Optional[list] = None
|
| 370 |
+
):
|
| 371 |
+
if config_path is None:
|
| 372 |
+
config_path = project_root / "configs" / "train.yaml"
|
| 373 |
+
with open(config_path, 'r') as f:
|
| 374 |
+
self.config = yaml.safe_load(f)
|
| 375 |
+
|
| 376 |
+
self.features_cfg = self.config.get('features', {})
|
| 377 |
+
self.sr = self.features_cfg.get('sr', 16000)
|
| 378 |
+
self.win_s = self.features_cfg.get('win_s', 0.032)
|
| 379 |
+
self.hop_s = self.features_cfg.get('hop_s', 0.010)
|
| 380 |
+
self.nfft = self.features_cfg.get('nfft', 1024)
|
| 381 |
+
self.K = self.features_cfg.get('K', 72)
|
| 382 |
+
|
| 383 |
+
if providers is None:
|
| 384 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 385 |
+
|
| 386 |
+
sess_options = ort.SessionOptions()
|
| 387 |
+
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 388 |
+
|
| 389 |
+
self.session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=providers)
|
| 390 |
+
self.input_name = self.session.get_inputs()[0].name
|
| 391 |
+
self.output_name = self.session.get_outputs()[0].name
|
| 392 |
+
input_shape = self.session.get_inputs()[0].shape
|
| 393 |
+
self.is_doa_model = input_shape[-1] == 513 if len(input_shape) == 4 else False
|
| 394 |
+
|
| 395 |
+
print(f"ONNX Model loaded: {onnx_path}")
|
| 396 |
+
print(f" Input shape: {input_shape}")
|
| 397 |
+
print(f" Model type: {'DoAEstimator' if self.is_doa_model else 'TFPoolClassifierNoCond'}")
|
| 398 |
+
print(f" Providers: {self.session.get_providers()}")
|
| 399 |
+
|
| 400 |
+
def compute_features(self, mixture: np.ndarray) -> np.ndarray:
|
| 401 |
+
if mixture.ndim == 1:
|
| 402 |
+
raise ValueError("Mixture must be multichannel (4 channels)")
|
| 403 |
+
if mixture.shape[0] != 4 and mixture.shape[1] == 4:
|
| 404 |
+
mixture = mixture.T
|
| 405 |
+
|
| 406 |
+
if mixture.shape[0] != 4:
|
| 407 |
+
raise ValueError(f"Expected 4 channels, got {mixture.shape[0]}")
|
| 408 |
+
|
| 409 |
+
x4 = mixture.astype(np.float32)
|
| 410 |
+
X, freqs, times = stft_multi(x4.T, fs=self.sr, win_s=self.win_s, hop_s=self.hop_s,
|
| 411 |
+
nfft=self.nfft, window="hann", center=True, pad_mode="reflect")
|
| 412 |
+
feats = compute_mag_phase_cos_sin(X, dtype=np.float32)
|
| 413 |
+
return feats
|
| 414 |
+
|
| 415 |
+
def inference_batch(self, feats: np.ndarray, batch_size: int = 25) -> np.ndarray:
|
| 416 |
+
T_frames, C_feat, F = feats.shape
|
| 417 |
+
assert C_feat == 12, f"Expected 12 feature channels, got {C_feat}"
|
| 418 |
+
|
| 419 |
+
all_logits = []
|
| 420 |
+
for start_idx in range(0, T_frames, batch_size):
|
| 421 |
+
end_idx = min(start_idx + batch_size, T_frames)
|
| 422 |
+
batch_feats = feats[start_idx:end_idx]
|
| 423 |
+
batch_T = batch_feats.shape[0]
|
| 424 |
+
|
| 425 |
+
if batch_T < batch_size:
|
| 426 |
+
padding = np.zeros((batch_size - batch_T, C_feat, F), dtype=batch_feats.dtype)
|
| 427 |
+
batch_feats = np.concatenate([batch_feats, padding], axis=0)
|
| 428 |
+
|
| 429 |
+
feats_tensor = batch_feats.transpose(1, 0, 2)[np.newaxis, ...]
|
| 430 |
+
outputs = self.session.run([self.output_name], {self.input_name: feats_tensor.astype(np.float32)})
|
| 431 |
+
batch_logits = outputs[0]
|
| 432 |
+
|
| 433 |
+
if batch_logits.ndim == 2:
|
| 434 |
+
if batch_logits.shape[0] == 1 and batch_logits.shape[1] == self.K:
|
| 435 |
+
batch_logits = np.tile(batch_logits, (batch_T, 1))
|
| 436 |
+
elif batch_logits.shape[0] == 1:
|
| 437 |
+
batch_logits = batch_logits[0]
|
| 438 |
+
else:
|
| 439 |
+
batch_logits = batch_logits[:batch_T]
|
| 440 |
+
elif batch_logits.ndim == 3:
|
| 441 |
+
batch_logits = batch_logits[0, :batch_T]
|
| 442 |
+
|
| 443 |
+
all_logits.append(batch_logits)
|
| 444 |
+
|
| 445 |
+
return np.concatenate(all_logits, axis=0)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# -------------------------
|
| 449 |
+
# Visualization
|
| 450 |
+
# -------------------------
|
| 451 |
+
class CurrentLineVisualizer:
|
| 452 |
+
def __init__(self, title: str = "Current DOA"):
|
| 453 |
+
self.fig = plt.figure(figsize=(7.5, 7.5))
|
| 454 |
+
self.ax = self.fig.add_subplot(111, projection='polar')
|
| 455 |
+
self._setup_axes(title)
|
| 456 |
+
plt.ion()
|
| 457 |
+
plt.show(block=False)
|
| 458 |
+
|
| 459 |
+
def _setup_axes(self, title: str):
|
| 460 |
+
self.ax.clear()
|
| 461 |
+
self.ax.set_title(title, fontsize=13, fontweight='bold', pad=16)
|
| 462 |
+
self.ax.set_theta_zero_location('N')
|
| 463 |
+
self.ax.set_theta_direction(-1)
|
| 464 |
+
self.ax.set_thetalim(0, 2*np.pi)
|
| 465 |
+
self.ax.set_ylim(0, 1.05)
|
| 466 |
+
self.ax.set_yticklabels([])
|
| 467 |
+
self.ax.add_patch(Circle((0, 0), 1.0, fill=False, color='gray', linestyle='--', linewidth=1, alpha=0.5))
|
| 468 |
+
self.ax.grid(alpha=0.2)
|
| 469 |
+
|
| 470 |
+
def update(self, peaks: List[Dict]):
|
| 471 |
+
self._setup_axes("Current DOA")
|
| 472 |
+
|
| 473 |
+
for pk in peaks[:3]:
|
| 474 |
+
az = float(pk["azimuth_deg"])
|
| 475 |
+
sc = float(pk.get("score", 0.2))
|
| 476 |
+
lw = 2.0 + 5.0 * float(np.clip(sc, 0.0, 0.6))
|
| 477 |
+
theta = np.deg2rad(az)
|
| 478 |
+
self.ax.plot([theta, theta], [0.0, 1.0], color='tab:green', linewidth=lw, solid_capstyle='round')
|
| 479 |
+
self.ax.text(theta, 1.02, f"{az:.0f}°", ha='center', va='bottom', fontsize=10,
|
| 480 |
+
color='tab:green', fontweight='bold')
|
| 481 |
+
|
| 482 |
+
self.fig.canvas.draw_idle()
|
| 483 |
+
self.fig.canvas.flush_events()
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# -------------------------
|
| 489 |
+
# Main streaming function
|
| 490 |
+
# -------------------------
|
| 491 |
+
def stream_onnx_inference(
|
| 492 |
+
onnx_path: str,
|
| 493 |
+
config_path: Optional[str] = None,
|
| 494 |
+
device_index: Optional[int] = None,
|
| 495 |
+
sample_rate: int = 16000,
|
| 496 |
+
window_ms: int = 200,
|
| 497 |
+
hop_ms: int = 100,
|
| 498 |
+
chunk_size: int = 1600,
|
| 499 |
+
cpu_only: bool = False,
|
| 500 |
+
# Histogram params
|
| 501 |
+
K: int = 72,
|
| 502 |
+
tau: float = 0.8,
|
| 503 |
+
smooth_k: int = 1,
|
| 504 |
+
min_peak_height: float = 0.10,
|
| 505 |
+
min_window_mass: float = 0.24,
|
| 506 |
+
min_sep_deg: float = 20.0,
|
| 507 |
+
min_active_ratio: float = 0.20,
|
| 508 |
+
max_sources: int = 3,
|
| 509 |
+
# Event gate params
|
| 510 |
+
level_delta_on_db: float = 2.5,
|
| 511 |
+
level_delta_off_db: float = 1.0,
|
| 512 |
+
level_min_dbfs: float = -60.0,
|
| 513 |
+
level_ema_alpha: float = 0.05,
|
| 514 |
+
event_hold_ms: int = 300,
|
| 515 |
+
min_R_clip: float = 0.18,
|
| 516 |
+
event_refractory_ms: int = 120,
|
| 517 |
+
# Onset params
|
| 518 |
+
onset_alpha: float = 0.05,
|
| 519 |
+
):
|
| 520 |
+
"""Stream inference from microphone using ONNX model."""
|
| 521 |
+
|
| 522 |
+
providers = ['CPUExecutionProvider'] if cpu_only else ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 523 |
+
infer = ONNXDOAStreaming(onnx_path, config_path, providers=providers)
|
| 524 |
+
|
| 525 |
+
# Override K if provided
|
| 526 |
+
if K != infer.K:
|
| 527 |
+
print(f"Warning: K mismatch. Model K={infer.K}, requested K={K}. Using model K.")
|
| 528 |
+
K = infer.K
|
| 529 |
+
|
| 530 |
+
det = HistDOADetector(
|
| 531 |
+
K=K, tau=tau, gamma=1.5, smooth_k=smooth_k,
|
| 532 |
+
window_bins=1, min_peak_height=min_peak_height, min_window_mass=min_window_mass,
|
| 533 |
+
min_sep_deg=min_sep_deg, min_active_ratio=min_active_ratio, max_sources=max_sources,
|
| 534 |
+
device="cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
gate = LevelChangeGate(
|
| 538 |
+
delta_on_db=level_delta_on_db, delta_off_db=level_delta_off_db,
|
| 539 |
+
level_min_dbfs=level_min_dbfs, ema_alpha=level_ema_alpha,
|
| 540 |
+
min_R_clip=min_R_clip,
|
| 541 |
+
hold_ms=event_hold_ms, refractory_ms=event_refractory_ms
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
onset = OnsetDetector(alpha=onset_alpha)
|
| 545 |
+
visualizer = CurrentLineVisualizer()
|
| 546 |
+
|
| 547 |
+
window_samples = int(sample_rate * window_ms / 1000)
|
| 548 |
+
hop_samples = int(sample_rate * hop_ms / 1000)
|
| 549 |
+
|
| 550 |
+
p = pyaudio.PyAudio()
|
| 551 |
+
|
| 552 |
+
if device_index is None:
|
| 553 |
+
for i in range(p.get_device_count()):
|
| 554 |
+
info = p.get_device_info_by_index(i)
|
| 555 |
+
name = info['name'].lower()
|
| 556 |
+
# Check by name first (PulseAudio might hide channels)
|
| 557 |
+
if 'respeaker' in name or 'seeed' in name or '2886' in name:
|
| 558 |
+
device_index = i
|
| 559 |
+
print(f"Auto-detected ReSpeaker at device {i}: {info['name']}")
|
| 560 |
+
break
|
| 561 |
+
|
| 562 |
+
if device_index is None:
|
| 563 |
+
print("\n[Audio] Could not auto-detect Respeaker. Use --device-index or --list-devices.\n")
|
| 564 |
+
p.terminate()
|
| 565 |
+
return
|
| 566 |
+
|
| 567 |
+
# Check device info
|
| 568 |
+
device_info = p.get_device_info_by_index(device_index)
|
| 569 |
+
print(f"Device info: {device_info['name']}")
|
| 570 |
+
print(f" Max input channels: {device_info['maxInputChannels']}")
|
| 571 |
+
print(f" Default sample rate: {device_info['defaultSampleRate']:.0f} Hz")
|
| 572 |
+
|
| 573 |
+
# If device shows 0 channels, it's likely managed by PulseAudio
|
| 574 |
+
# We'll still try to open it - sometimes it works despite the report
|
| 575 |
+
if device_info['maxInputChannels'] == 0:
|
| 576 |
+
print(" Warning: Device reports 0 channels (may be managed by PulseAudio)")
|
| 577 |
+
print(" Attempting to open anyway...")
|
| 578 |
+
|
| 579 |
+
CHANNELS = 6
|
| 580 |
+
RAW_CHANNELS = [1, 4, 3, 2] # your requested order
|
| 581 |
+
FORMAT = pyaudio.paInt16
|
| 582 |
+
|
| 583 |
+
audio_buffer = np.zeros((4, window_samples), dtype=np.float32)
|
| 584 |
+
buffer_fill = 0
|
| 585 |
+
start_time = time.time()
|
| 586 |
+
|
| 587 |
+
audio_queue = queue.Queue()
|
| 588 |
+
stream_closed = False
|
| 589 |
+
|
| 590 |
+
def _fill_buffer(in_data, frame_count, time_info, status_flags):
|
| 591 |
+
if not stream_closed:
|
| 592 |
+
audio_queue.put(in_data)
|
| 593 |
+
return None, pyaudio.paContinue
|
| 594 |
+
|
| 595 |
+
try:
|
| 596 |
+
# Try to open the stream - PyAudio will validate channels
|
| 597 |
+
stream = p.open(
|
| 598 |
+
format=FORMAT,
|
| 599 |
+
channels=CHANNELS,
|
| 600 |
+
rate=sample_rate,
|
| 601 |
+
input=True,
|
| 602 |
+
input_device_index=device_index,
|
| 603 |
+
frames_per_buffer=chunk_size,
|
| 604 |
+
stream_callback=_fill_buffer
|
| 605 |
+
)
|
| 606 |
+
print(" Successfully opened audio stream with 6 channels")
|
| 607 |
+
except Exception as e:
|
| 608 |
+
print(f"\n[Audio] Could not open input device (index {device_index}).")
|
| 609 |
+
print(f" Error: {e}")
|
| 610 |
+
print("\n The ReSpeaker device is likely locked by PulseAudio.")
|
| 611 |
+
print(" Solutions:")
|
| 612 |
+
print(" 1. Temporarily stop PulseAudio: pulseaudio --kill")
|
| 613 |
+
print(" 2. Then restart it after: pulseaudio --start")
|
| 614 |
+
print(" 3. Or configure PulseAudio to allow direct ALSA access\n")
|
| 615 |
+
p.terminate()
|
| 616 |
+
return
|
| 617 |
+
|
| 618 |
+
stream.start_stream()
|
| 619 |
+
print(f"\n[Streaming] Started. Window: {window_ms}ms, Hop: {hop_ms}ms")
|
| 620 |
+
print(" Press Ctrl+C to stop.\n")
|
| 621 |
+
|
| 622 |
+
try:
|
| 623 |
+
while True:
|
| 624 |
+
try:
|
| 625 |
+
data = audio_queue.get(timeout=1.0)
|
| 626 |
+
except queue.Empty:
|
| 627 |
+
continue
|
| 628 |
+
|
| 629 |
+
chunk_all = chunk_to_floatarray(data, CHANNELS) # (6, N)
|
| 630 |
+
audio_chunk = chunk_all[RAW_CHANNELS, :] # (4, N)
|
| 631 |
+
n = audio_chunk.shape[1]
|
| 632 |
+
|
| 633 |
+
if buffer_fill + n <= window_samples:
|
| 634 |
+
audio_buffer[:, buffer_fill:buffer_fill + n] = audio_chunk
|
| 635 |
+
buffer_fill += n
|
| 636 |
+
continue
|
| 637 |
+
|
| 638 |
+
remaining = window_samples - buffer_fill
|
| 639 |
+
if remaining > 0:
|
| 640 |
+
audio_buffer[:, buffer_fill:] = audio_chunk[:, :remaining]
|
| 641 |
+
buffer_fill = window_samples
|
| 642 |
+
|
| 643 |
+
# Inference
|
| 644 |
+
t0 = time.perf_counter()
|
| 645 |
+
feats = infer.compute_features(audio_buffer)
|
| 646 |
+
logits = infer.inference_batch(feats)
|
| 647 |
+
t_model = (time.perf_counter() - t0) * 1000.0
|
| 648 |
+
|
| 649 |
+
T = logits.shape[0]
|
| 650 |
+
energies = frame_rms_energy(audio_buffer, T)
|
| 651 |
+
flux = spectral_flux_per_frame(audio_buffer, T)
|
| 652 |
+
flux_recent = float(max(flux[-1], flux[-2] if T >= 2 else 0.0))
|
| 653 |
+
flux_z = onset.update_flux(flux_recent)
|
| 654 |
+
coh = OnsetDetector.last_segment_coherence(audio_buffer, T)
|
| 655 |
+
|
| 656 |
+
# DOA detection (no VAD)
|
| 657 |
+
t1 = time.perf_counter()
|
| 658 |
+
det_result = det.detect(logits)
|
| 659 |
+
t_hist = (time.perf_counter() - t1) * 1000.0
|
| 660 |
+
|
| 661 |
+
peaks = det_result["peaks"]
|
| 662 |
+
peaks_count = len(peaks)
|
| 663 |
+
Rmax = det_result["R_clip"]
|
| 664 |
+
|
| 665 |
+
level = rms_dbfs(audio_buffer)
|
| 666 |
+
now = time.time() - start_time
|
| 667 |
+
|
| 668 |
+
gate_open, diff_db = gate.update(level_dbfs=level, now_s=now,
|
| 669 |
+
peaks_count=peaks_count, R_clip_max=Rmax)
|
| 670 |
+
|
| 671 |
+
if gate_open:
|
| 672 |
+
visualizer.update(peaks)
|
| 673 |
+
gate_str = "OPEN "
|
| 674 |
+
else:
|
| 675 |
+
visualizer.update([])
|
| 676 |
+
gate_str = "CLOSED"
|
| 677 |
+
|
| 678 |
+
print(f"[{now:6.2f}s] LVL={level:6.1f} dBFS diff={diff_db:+4.1f} | "
|
| 679 |
+
f"FLUXz={flux_z:4.2f} COH={coh:4.2f} | "
|
| 680 |
+
f"GATE={gate_str} | "
|
| 681 |
+
f"MODEL={t_model:5.1f}ms HIST={t_hist:5.1f}ms | "
|
| 682 |
+
f"DOA(R={Rmax:.2f}, n={peaks_count})", end="")
|
| 683 |
+
if peaks:
|
| 684 |
+
az_str = ", ".join([f"{p['azimuth_deg']:.0f}°" for p in peaks[:3]])
|
| 685 |
+
print(f" [{az_str}]")
|
| 686 |
+
else:
|
| 687 |
+
print()
|
| 688 |
+
|
| 689 |
+
# Slide buffer
|
| 690 |
+
audio_buffer[:, :-hop_samples] = audio_buffer[:, hop_samples:]
|
| 691 |
+
buffer_fill = window_samples - hop_samples
|
| 692 |
+
|
| 693 |
+
if n > remaining:
|
| 694 |
+
carry = min(n - remaining, hop_samples)
|
| 695 |
+
if carry > 0:
|
| 696 |
+
audio_buffer[:, buffer_fill:buffer_fill + carry] = audio_chunk[:, remaining:remaining + carry]
|
| 697 |
+
buffer_fill += carry
|
| 698 |
+
|
| 699 |
+
except KeyboardInterrupt:
|
| 700 |
+
print("\n[Streaming] Stopped by user.")
|
| 701 |
+
finally:
|
| 702 |
+
stream_closed = True
|
| 703 |
+
try:
|
| 704 |
+
stream.stop_stream()
|
| 705 |
+
stream.close()
|
| 706 |
+
except Exception:
|
| 707 |
+
pass
|
| 708 |
+
p.terminate()
|
| 709 |
+
plt.close('all')
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def main():
|
| 713 |
+
parser = argparse.ArgumentParser(description="Stream ONNX DOA inference from microphone")
|
| 714 |
+
parser.add_argument('--onnx', type=str, required=False, help='Path to ONNX model file')
|
| 715 |
+
parser.add_argument('--config', type=str, default=None, help='Path to config.yaml')
|
| 716 |
+
parser.add_argument('--device-index', type=int, default=None, help='Audio device index')
|
| 717 |
+
parser.add_argument('--sample-rate', type=int, default=16000, help='Sample rate (Hz)')
|
| 718 |
+
parser.add_argument('--window-ms', type=int, default=200, help='Window length (ms)')
|
| 719 |
+
parser.add_argument('--hop-ms', type=int, default=100, help='Hop length (ms)')
|
| 720 |
+
parser.add_argument('--chunk-size', type=int, default=1600, help='Audio chunk size')
|
| 721 |
+
parser.add_argument('--cpu-only', action='store_true', help='Use CPU only')
|
| 722 |
+
parser.add_argument('--list-devices', action='store_true', help='List available audio devices')
|
| 723 |
+
|
| 724 |
+
# Histogram params
|
| 725 |
+
parser.add_argument('--K', type=int, default=72, help='Number of azimuth bins')
|
| 726 |
+
parser.add_argument('--tau', type=float, default=0.8, help='Softmax temperature')
|
| 727 |
+
parser.add_argument('--smooth-k', type=int, default=1, help='Smoothing kernel size')
|
| 728 |
+
parser.add_argument('--min-peak-height', type=float, default=0.10, help='Min peak height')
|
| 729 |
+
parser.add_argument('--min-window-mass', type=float, default=0.24, help='Min window mass')
|
| 730 |
+
parser.add_argument('--min-sep-deg', type=float, default=20.0, help='Min separation (deg)')
|
| 731 |
+
parser.add_argument('--min-active-ratio', type=float, default=0.20, help='Min active ratio')
|
| 732 |
+
parser.add_argument('--max-sources', type=int, default=3, help='Max sources')
|
| 733 |
+
|
| 734 |
+
# Event gate params
|
| 735 |
+
parser.add_argument('--level-delta-on-db', type=float, default=2.5, help='Level delta on (dB)')
|
| 736 |
+
parser.add_argument('--level-delta-off-db', type=float, default=1.0, help='Level delta off (dB)')
|
| 737 |
+
parser.add_argument('--level-min-dbfs', type=float, default=-60.0, help='Min level (dBFS)')
|
| 738 |
+
parser.add_argument('--level-ema-alpha', type=float, default=0.05, help='Level EMA alpha')
|
| 739 |
+
parser.add_argument('--event-hold-ms', type=int, default=300, help='Event hold (ms)')
|
| 740 |
+
parser.add_argument('--min-R-clip', type=float, default=0.18, help='Min R clip')
|
| 741 |
+
parser.add_argument('--event-refractory-ms', type=int, default=120, help='Event refractory (ms)')
|
| 742 |
+
|
| 743 |
+
# Onset params
|
| 744 |
+
parser.add_argument('--onset-alpha', type=float, default=0.05, help='Onset EMA alpha')
|
| 745 |
+
|
| 746 |
+
args = parser.parse_args()
|
| 747 |
+
|
| 748 |
+
if args.list_devices:
|
| 749 |
+
p = pyaudio.PyAudio()
|
| 750 |
+
print("\nAvailable audio input devices:")
|
| 751 |
+
print("-" * 80)
|
| 752 |
+
for i in range(p.get_device_count()):
|
| 753 |
+
info = p.get_device_info_by_index(i)
|
| 754 |
+
if info['maxInputChannels'] > 0:
|
| 755 |
+
print(f"Device {i}: {info['name']}")
|
| 756 |
+
print(f" Channels: {info['maxInputChannels']}, Sample Rate: {info['defaultSampleRate']:.0f} Hz\n")
|
| 757 |
+
p.terminate()
|
| 758 |
+
return
|
| 759 |
+
|
| 760 |
+
if args.onnx is None:
|
| 761 |
+
parser.error("--onnx is required (unless using --list-devices)")
|
| 762 |
+
|
| 763 |
+
onnx_path = Path(args.onnx)
|
| 764 |
+
if not onnx_path.exists():
|
| 765 |
+
parser.error(f"ONNX model not found: {onnx_path}")
|
| 766 |
+
|
| 767 |
+
stream_onnx_inference(
|
| 768 |
+
onnx_path=str(onnx_path),
|
| 769 |
+
config_path=args.config,
|
| 770 |
+
device_index=args.device_index,
|
| 771 |
+
sample_rate=args.sample_rate,
|
| 772 |
+
window_ms=args.window_ms,
|
| 773 |
+
hop_ms=args.hop_ms,
|
| 774 |
+
chunk_size=1600, # args.chunk_size,
|
| 775 |
+
cpu_only=args.cpu_only,
|
| 776 |
+
K=args.K,
|
| 777 |
+
tau=args.tau,
|
| 778 |
+
smooth_k=args.smooth_k,
|
| 779 |
+
min_peak_height=args.min_peak_height,
|
| 780 |
+
min_window_mass=args.min_window_mass,
|
| 781 |
+
min_sep_deg=args.min_sep_deg,
|
| 782 |
+
min_active_ratio=args.min_active_ratio,
|
| 783 |
+
max_sources=args.max_sources,
|
| 784 |
+
level_delta_on_db=args.level_delta_on_db,
|
| 785 |
+
level_delta_off_db=args.level_delta_off_db,
|
| 786 |
+
level_min_dbfs=args.level_min_dbfs,
|
| 787 |
+
level_ema_alpha=args.level_ema_alpha,
|
| 788 |
+
event_hold_ms=args.event_hold_ms,
|
| 789 |
+
min_R_clip=args.min_R_clip,
|
| 790 |
+
event_refractory_ms=args.event_refractory_ms,
|
| 791 |
+
onset_alpha=args.onset_alpha,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
if __name__ == "__main__":
|
| 796 |
+
main()
|
silero_vad.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a35ebf52fd3ce5f1469b2a36158dba761bc47b973ea3382b3186ca15b1f5af28
|
| 3 |
+
size 1807522
|