# speech_direction_of_arrival / onnx_stream_microphone.py
# (Hugging Face upload metadata, kept as comments so the file stays valid Python:
#  uploaded by EtMmohammedHafsati — "Upload folder using huggingface_hub",
#  commit c2f1451, verified)
#!/usr/bin/env python3
"""
Real-time DOA inference using ONNX model with microphone streaming.
Includes histogram-based detection, event gates, and onset detection.
"""
import sys
from pathlib import Path

# Add the repo-local "src" directory to sys.path so `mirokai_doa` imports work
# when this file is run as a standalone script (no editable install required).
project_root = Path(__file__).parent
src_dir = project_root / "src"
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))
import math
import numpy as np
import time
import queue
import argparse
import pyaudio
import onnxruntime as ort
import yaml
from typing import Optional, Dict, List, Tuple
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import torch
import torch.nn.functional as F
from mirokai_doa.features import stft_multi, compute_mag_phase_cos_sin
# -------------------------
# Math helpers (numpy version)
# -------------------------
def _angles_deg_np(K: int):
bin_size = 360.0 / K
deg = (np.arange(K, dtype=np.float32) + 0.5) * bin_size
rad = deg * np.pi / 180.0
return deg, np.cos(rad), np.sin(rad), bin_size
def _softmax_temp_np(logits: np.ndarray, tau: float = 0.8) -> np.ndarray:
exp_logits = np.exp((logits - np.max(logits, axis=-1, keepdims=True)) / tau)
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
def _circular_window_sum_np(row: np.ndarray, idx: int, half_w: int) -> float:
K = row.size
if half_w <= 0:
return float(row[idx])
acc = 0.0
for d in range(-half_w, half_w + 1):
acc += float(row[(idx + d) % K])
return acc
def _parabolic_peak_refine_np(row: np.ndarray, k: int) -> float:
K = row.size
km1, kp1 = (k - 1) % K, (k + 1) % K
y1, y2, y3 = float(row[km1]), float(row[k]), float(row[kp1])
denom = (y1 - 2 * y2 + y3)
if abs(denom) < 1e-9:
return 0.0
delta = 0.5 * (y1 - y3) / denom
return float(max(min(delta, 0.5), -0.5))
def _min_circ_separation_bins(a: int, chosen: List[int], K: int) -> int:
if not chosen:
return K
dmin = K
for j in chosen:
d = abs(a - j)
d = min(d, K - d)
dmin = min(dmin, d)
return dmin
# -------------------------
# Audio helpers
# -------------------------
def byte_to_float(data: bytes) -> np.ndarray:
    """Decode int16 PCM bytes into float32 samples scaled to [-1, 1)."""
    return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
def chunk_to_floatarray(data: bytes, channels: int) -> np.ndarray:
    """Deinterleave int16 PCM bytes into a (channels, frames) float32 array."""
    samples = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
    return samples.reshape(-1, channels).T
def rms_dbfs(x: np.ndarray, eps: float = 1e-9) -> float:
    """RMS level of ``x`` in dBFS, floored at 20*log10(eps) so silence never hits log(0)."""
    rms = np.sqrt(np.mean(x * x))
    return 20.0 * np.log10(max(rms, eps))
def frame_rms_energy(audio_buffer: np.ndarray, T: int) -> np.ndarray:
    """Split audio_buffer (C,N) into T equal segments; return per-frame RMS (normalized).

    Each segment RMS is divided by the mean RMS (floored at 1e-6); T <= 0 yields
    a single-element array of ones.
    """
    _, N = audio_buffer.shape
    if T <= 0:
        return np.ones(1, dtype=np.float32)
    bounds = np.linspace(0, N, T + 1, dtype=int)
    energies = np.empty(T, dtype=np.float32)
    for i, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
        segment = audio_buffer[:, lo:hi]
        energies[i] = 0.0 if segment.size == 0 else np.sqrt((segment * segment).mean())
    return energies / max(energies.mean(), 1e-6)
def spectral_flux_per_frame(audio_buffer: np.ndarray, T: int) -> np.ndarray:
    """Compute per-frame spectral flux across T segments of the mono mix of (C,N) audio.

    Flux at frame t is the sum of positive magnitude increases from frame t-1,
    normalized by frame t's total magnitude; frame 0 is always 0. T <= 1 returns
    an all-zero array of length T.
    """
    _, N = audio_buffer.shape
    if T <= 1:
        return np.zeros((T,), dtype=np.float32)
    mono = audio_buffer.mean(axis=0)
    bounds = np.linspace(0, N, T + 1, dtype=int)
    # Magnitude spectrum per segment (Hann-tapered when long enough to matter).
    spectra = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        segment = mono[lo:hi]
        if segment.size == 0:
            spectra.append(np.zeros(1, dtype=np.float32))
            continue
        taper = np.hanning(segment.size) if segment.size > 8 else np.ones_like(segment)
        spectrum = np.fft.rfft(segment * taper, n=segment.size)
        spectra.append(np.abs(spectrum).astype(np.float32))
    flux = np.zeros(T, dtype=np.float32)
    for t in range(1, T):
        prev, cur = spectra[t - 1], spectra[t]
        n_bins = min(prev.size, cur.size)
        if n_bins == 0:
            continue
        rise = np.maximum(cur[:n_bins] - prev[:n_bins], 0.0)
        flux[t] = float(np.sum(rise) / (np.sum(cur[:n_bins]) + 1e-6))
    return flux
# -------------------------
# Onset detector
# -------------------------
class OnsetDetector:
    """Scores spectral-flux onsets as z-scores against EMA-tracked mean/variance."""

    def __init__(self, alpha: float = 0.05):
        # alpha: EMA smoothing factor for the running flux statistics.
        self.alpha = float(alpha)
        self.mu = 0.0        # running mean of flux
        self.var = 1.0       # running variance of flux
        self.inited = False  # whether mu/var were seeded from the first sample

    def update_flux(self, flux_recent: float) -> float:
        """Fold a new flux value into the running stats and return its z-score.

        On the first call the statistics are seeded from the observation itself,
        so the first z-score is 0 (delta is zero right after seeding).
        """
        if not self.inited:
            self.mu = flux_recent
            self.var = 1e-3 + abs(flux_recent)  # crude scale-aware variance seed
            self.inited = True
        delta = flux_recent - self.mu
        self.mu += self.alpha * delta
        self.var = (1 - self.alpha) * self.var + self.alpha * delta * delta
        sigma = max(np.sqrt(self.var), 1e-6)  # floor avoids divide-by-zero
        z = (flux_recent - self.mu) / sigma
        return float(z)

    @staticmethod
    def last_segment_coherence(audio_buffer: np.ndarray, T: int,
                               pairs: Tuple[Tuple[int, int], ...] = ((0, 1), (0, 2), (0, 3))) -> float:
        """Max |normalized cross-correlation| between channel pairs over the last of T segments.

        Fix: the default for ``pairs`` was a mutable list literal (shared across
        calls); it is now an immutable tuple with identical contents. Callers may
        still pass any iterable of (i, j) channel-index pairs.

        Returns 0.0 when T < 1 or the last segment is shorter than 16 samples.
        """
        C, N = audio_buffer.shape
        if T < 1:
            return 0.0
        edges = np.linspace(0, N, T + 1, dtype=int)
        s0, s1 = int(edges[-2]), int(edges[-1])
        seg = audio_buffer[:, s0:s1]
        if seg.shape[1] < 16:
            return 0.0
        rmax = 0.0
        for (i, j) in pairs:
            xi = seg[i] - seg[i].mean()
            xj = seg[j] - seg[j].mean()
            denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-9)
            r = float(np.dot(xi, xj) / denom)
            rmax = max(rmax, abs(r))
        return rmax
# -------------------------
# Histogram DOA detector (numpy/torch hybrid)
# -------------------------
class HistDOADetector:
    """Histogram-based multi-source DOA detector.

    Turns per-frame azimuth logits [T, K] into per-frame posteriors via a
    temperature softmax, accumulates them into one circular K-bin histogram
    (each frame weighted by its VAD mask times its directional concentration
    R_t ** gamma), then picks up to ``max_sources`` well-separated peaks.
    """
    def __init__(
        self,
        K: int = 72,                     # number of azimuth bins (each 360/K degrees wide)
        tau: float = 0.8,                # softmax temperature applied to logits
        gamma: float = 1.5,              # exponent on per-frame resultant length R_t
        smooth_k: int = 1,               # circular box-filter half-width (bins); 0 disables
        window_bins: int = 1,            # half-width of the mass window around a peak
        min_peak_height: float = 0.10,   # min histogram value at a peak bin
        min_window_mass: float = 0.24,   # min probability mass inside the peak window
        min_sep_deg: float = 20.0,       # min angular separation between reported peaks
        min_active_ratio: float = 0.20,  # min mean VAD-mask value required to report peaks
        max_sources: int = 3,            # cap on the number of reported peaks
        device: str = "cpu",             # torch device used for histogram math
    ):
        self.K = int(K)
        self.tau = float(tau)
        self.gamma = float(gamma)
        self.smooth_k = int(smooth_k)
        self.window_bins = int(window_bins)
        self.min_peak_height = float(min_peak_height)
        self.min_window_mass = float(min_window_mass)
        self.min_sep_deg = float(min_sep_deg)
        self.min_active_ratio = float(min_active_ratio)
        self.max_sources = int(max_sources)
        self.device = torch.device(device)
        # Precomputed bin-center angles and their cos/sin for resultant-vector math.
        self._deg, self._cos, self._sin, self._bin_size = self._angles_deg(self.K)
    def _angles_deg(self, K: int):
        """Return bin-center angles (deg), their cos/sin tensors, and the bin width."""
        bin_size = 360.0 / K
        deg = torch.arange(K, device=self.device, dtype=torch.float32) + 0.5
        deg = deg * bin_size
        rad = deg * math.pi / 180.0
        return deg, torch.cos(rad), torch.sin(rad), bin_size
    def _aggregate_histogram(self, logits: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Aggregate histogram from logits and VAD mask.

        Returns (hist, active_ratio, R_clip): the normalized K-bin histogram,
        the mean mask value, and the resultant length of the histogram itself
        (a 0..1 concentration measure of the aggregated direction estimate).
        """
        logits_t = torch.from_numpy(logits).float().to(self.device)
        mask_t = torch.from_numpy(mask).float().to(self.device)
        probs = F.softmax(logits_t / self.tau, dim=-1)  # [T,K]
        T = probs.shape[0]
        m = mask_t
        # Weighted histogram: each frame contributes its posterior, weighted by
        # mask * (per-frame resultant length)^gamma so diffuse frames count less.
        x = torch.matmul(probs, self._cos)
        y = torch.matmul(probs, self._sin)
        R_t = torch.clamp(torch.sqrt(x * x + y * y), 0, 1)
        w = m * (R_t ** self.gamma)
        if w.sum() <= 0:
            # All frames masked/diffuse: fall back to a tiny uniform weight.
            w = torch.ones_like(w) * 1e-6
        hist = torch.matmul(w, probs)
        hist = hist / hist.sum().clamp_min(1e-8)
        if self.smooth_k > 0:
            # Circular box smoothing: wrap-pad by s bins on both ends, then conv1d.
            s = self.smooth_k
            pad = torch.cat([hist[-s:], hist, hist[:s]], dim=0).view(1, 1, -1)
            kernel = torch.ones(1, 1, 2 * s + 1, device=self.device) / (2 * s + 1)
            hist = F.conv1d(pad, kernel, padding=0).view(-1)
        X = torch.dot(hist, self._cos)
        Y = torch.dot(hist, self._sin)
        R_clip = float(torch.sqrt(X * X + Y * Y).item())
        active_ratio = float(m.mean().item())
        return hist.detach().cpu().numpy(), active_ratio, R_clip
    def _pick_peaks(self, hist: np.ndarray) -> List[Dict[str, float]]:
        """Pick peaks from histogram.

        Candidates are strict local maxima (circular neighbors), visited in
        descending height; each must clear the separation, height, and window
        mass thresholds. Accepted peaks are parabolic-refined to sub-bin angles.
        """
        hist_t = torch.from_numpy(hist).float()
        K = self.K
        bin_size = self._bin_size
        left = torch.roll(hist_t, 1, 0)
        right = torch.roll(hist_t, -1, 0)
        cand_idxs = ((hist_t > left) & (hist_t > right)).nonzero(as_tuple=False).flatten().tolist()
        cand_idxs.sort(key=lambda i: float(hist_t[i].item()), reverse=True)
        chosen, out = [], []
        min_sep_bins = max(1, int(round(self.min_sep_deg / bin_size)))
        for idx in cand_idxs:
            # Too close to an already accepted peak (circular distance).
            if _min_circ_separation_bins(idx, chosen, K) < min_sep_bins:
                continue
            if float(hist_t[idx].item()) < self.min_peak_height:
                continue
            mass = _circular_window_sum_np(hist, idx, self.window_bins)
            if mass < self.min_window_mass:
                continue
            delta = _parabolic_peak_refine_np(hist, idx)
            angle_deg = ((idx + 0.5 + delta) * bin_size) % 360.0
            out.append({"azimuth_deg": angle_deg, "score": float(mass)})
            chosen.append(idx)
            if len(out) >= self.max_sources:
                break
        return out
    def detect(self, logits: np.ndarray) -> Dict[str, any]:
        """Detect DOA from logits (no VAD separation).

        NOTE(review): the annotation uses the builtin ``any`` (harmless at
        runtime); ``typing.Any`` was presumably intended.
        """
        # Use all frames (no VAD masking)
        mask = np.ones(logits.shape[0], dtype=np.float32)
        hist, active_ratio, R_clip = self._aggregate_histogram(logits, mask)
        peaks = self._pick_peaks(hist) if active_ratio >= self.min_active_ratio else []
        bins_deg = (np.arange(self.K) + 0.5) * (360.0 / self.K)
        return {
            "peaks": peaks,
            "active_ratio": active_ratio,
            "R_clip": R_clip,
            "hist": hist,
            "bins_deg": bins_deg,
            "has_event": bool(peaks),
        }
# -------------------------
# Event gate
# -------------------------
class LevelChangeGate:
    """Hysteresis gate for acoustic events.

    Opens when the level jumps ``delta_on_db`` above a slow EMA background (or
    when a confident DOA peak appears), and closes only after ``hold_ms`` once
    the level has fallen back and no confident peaks remain. A refractory
    period suppresses re-opening immediately after a transition.
    """

    def __init__(
        self,
        delta_on_db: float = 2.5,
        delta_off_db: float = 1.0,
        level_min_dbfs: float = -60.0,
        ema_alpha: float = 0.05,
        min_R_clip: float = 0.18,
        hold_ms: int = 300,
        refractory_ms: int = 120
    ):
        self.delta_on_db = float(delta_on_db)
        self.delta_off_db = float(delta_off_db)
        self.level_min_dbfs = float(level_min_dbfs)
        self.ema_alpha = float(ema_alpha)
        self.min_R_clip = float(min_R_clip)
        self.hold_s = float(hold_ms) / 1000.0
        self.refractory_s = float(refractory_ms) / 1000.0
        self.bg_dbfs = None          # EMA background level; seeded on first update
        self.active = False          # current gate state
        self.last_change_time = 0.0  # timestamp of the last open/close transition

    def update(self, level_dbfs: float, now_s: float,
               peaks_count: int, R_clip_max: float):
        """Advance the gate one step; return (active, diff_db).

        diff_db is the level relative to the background *before* this step's
        EMA update (the background is refreshed last, even while open).
        """
        if self.bg_dbfs is None:
            self.bg_dbfs = level_dbfs
        diff_db = level_dbfs - self.bg_dbfs
        past_refractory = (now_s - self.last_change_time) >= self.refractory_s
        level_trigger = level_dbfs > self.level_min_dbfs and diff_db >= self.delta_on_db
        doa_trigger = peaks_count > 0 and R_clip_max >= self.min_R_clip
        if not self.active:
            if past_refractory and (level_trigger or doa_trigger):
                self.active = True
                self.last_change_time = now_s
        elif (now_s - self.last_change_time) >= self.hold_s:
            quiet = diff_db <= self.delta_off_db
            no_confident_doa = peaks_count == 0 or R_clip_max < self.min_R_clip
            if quiet and no_confident_doa:
                self.active = False
                self.last_change_time = now_s
        # Background tracks the level slowly, regardless of gate state.
        self.bg_dbfs = (1.0 - self.ema_alpha) * self.bg_dbfs + self.ema_alpha * level_dbfs
        return self.active, diff_db
# -------------------------
# ONNX Inference
# -------------------------
class ONNXDOAStreaming:
    """Wraps an ONNX DOA model: loads feature config from YAML, builds the
    onnxruntime session, and exposes feature extraction + batched inference."""
    def __init__(
        self,
        onnx_path: str,
        config_path: Optional[str] = None,
        providers: Optional[list] = None
    ):
        # Default to the repo's training config for feature parameters.
        if config_path is None:
            config_path = project_root / "configs" / "train.yaml"
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
        # Feature/STFT parameters, with fallbacks matching the defaults used elsewhere.
        self.features_cfg = self.config.get('features', {})
        self.sr = self.features_cfg.get('sr', 16000)
        self.win_s = self.features_cfg.get('win_s', 0.032)
        self.hop_s = self.features_cfg.get('hop_s', 0.010)
        self.nfft = self.features_cfg.get('nfft', 1024)
        self.K = self.features_cfg.get('K', 72)
        if providers is None:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        input_shape = self.session.get_inputs()[0].shape
        # Heuristic model-type sniff: a 4-D input whose last dim is 513 is treated
        # as the DoAEstimator variant — presumably nfft/2+1 bins; TODO confirm.
        self.is_doa_model = input_shape[-1] == 513 if len(input_shape) == 4 else False
        print(f"ONNX Model loaded: {onnx_path}")
        print(f"  Input shape: {input_shape}")
        print(f"  Model type: {'DoAEstimator' if self.is_doa_model else 'TFPoolClassifierNoCond'}")
        print(f"  Providers: {self.session.get_providers()}")
    def compute_features(self, mixture: np.ndarray) -> np.ndarray:
        """Compute STFT-based features for a 4-channel mixture.

        Accepts (4, N) or (N, 4) and transposes the latter; raises ValueError
        for mono or non-4-channel input. Returns whatever
        ``compute_mag_phase_cos_sin`` produces — assumed (T, 12, F), since
        ``inference_batch`` asserts 12 feature channels; TODO confirm.
        """
        if mixture.ndim == 1:
            raise ValueError("Mixture must be multichannel (4 channels)")
        if mixture.shape[0] != 4 and mixture.shape[1] == 4:
            mixture = mixture.T
        if mixture.shape[0] != 4:
            raise ValueError(f"Expected 4 channels, got {mixture.shape[0]}")
        x4 = mixture.astype(np.float32)
        X, freqs, times = stft_multi(x4.T, fs=self.sr, win_s=self.win_s, hop_s=self.hop_s,
                                     nfft=self.nfft, window="hann", center=True, pad_mode="reflect")
        feats = compute_mag_phase_cos_sin(X, dtype=np.float32)
        return feats
    def inference_batch(self, feats: np.ndarray, batch_size: int = 25) -> np.ndarray:
        """Run the ONNX model over (T, 12, F) features in fixed-size time batches.

        The last batch is zero-padded to ``batch_size`` frames and the padded
        outputs are trimmed. Handles several possible output layouts and
        returns per-frame logits concatenated along time, shape (T, K-ish).
        """
        T_frames, C_feat, F = feats.shape
        assert C_feat == 12, f"Expected 12 feature channels, got {C_feat}"
        all_logits = []
        for start_idx in range(0, T_frames, batch_size):
            end_idx = min(start_idx + batch_size, T_frames)
            batch_feats = feats[start_idx:end_idx]
            batch_T = batch_feats.shape[0]
            if batch_T < batch_size:
                # Zero-pad the trailing partial batch to the fixed batch length.
                padding = np.zeros((batch_size - batch_T, C_feat, F), dtype=batch_feats.dtype)
                batch_feats = np.concatenate([batch_feats, padding], axis=0)
            # Model expects [1, C_feat, T, F] (time moved behind channels).
            feats_tensor = batch_feats.transpose(1, 0, 2)[np.newaxis, ...]
            outputs = self.session.run([self.output_name], {self.input_name: feats_tensor.astype(np.float32)})
            batch_logits = outputs[0]
            # Normalize output layout: single-vector outputs are broadcast over
            # the batch's frames; batched outputs are trimmed to the real length.
            if batch_logits.ndim == 2:
                if batch_logits.shape[0] == 1 and batch_logits.shape[1] == self.K:
                    batch_logits = np.tile(batch_logits, (batch_T, 1))
                elif batch_logits.shape[0] == 1:
                    batch_logits = batch_logits[0]
                else:
                    batch_logits = batch_logits[:batch_T]
            elif batch_logits.ndim == 3:
                batch_logits = batch_logits[0, :batch_T]
            all_logits.append(batch_logits)
        return np.concatenate(all_logits, axis=0)
# -------------------------
# Visualization
# -------------------------
class CurrentLineVisualizer:
    """Interactive polar plot drawing one radial line per detected source."""
    def __init__(self, title: str = "Current DOA"):
        self.fig = plt.figure(figsize=(7.5, 7.5))
        self.ax = self.fig.add_subplot(111, projection='polar')
        self._setup_axes(title)
        # Interactive mode + non-blocking show so the streaming loop keeps running.
        plt.ion()
        plt.show(block=False)
    def _setup_axes(self, title: str):
        """Reset the polar axes: 0° at north, clockwise angles, unit radius."""
        self.ax.clear()
        self.ax.set_title(title, fontsize=13, fontweight='bold', pad=16)
        self.ax.set_theta_zero_location('N')
        self.ax.set_theta_direction(-1)
        self.ax.set_thetalim(0, 2*np.pi)
        self.ax.set_ylim(0, 1.05)
        self.ax.set_yticklabels([])
        self.ax.add_patch(Circle((0, 0), 1.0, fill=False, color='gray', linestyle='--', linewidth=1, alpha=0.5))
        self.ax.grid(alpha=0.2)
    def update(self, peaks: List[Dict]):
        """Redraw with up to 3 peaks; line width scales with each peak's score."""
        self._setup_axes("Current DOA")
        for pk in peaks[:3]:
            az = float(pk["azimuth_deg"])
            sc = float(pk.get("score", 0.2))
            # Map score 0..0.6 onto line width 2..5 (clipped).
            lw = 2.0 + 5.0 * float(np.clip(sc, 0.0, 0.6))
            theta = np.deg2rad(az)
            self.ax.plot([theta, theta], [0.0, 1.0], color='tab:green', linewidth=lw, solid_capstyle='round')
            self.ax.text(theta, 1.02, f"{az:.0f}°", ha='center', va='bottom', fontsize=10,
                         color='tab:green', fontweight='bold')
        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()
# -------------------------
# Main streaming function
# -------------------------
def stream_onnx_inference(
    onnx_path: str,
    config_path: Optional[str] = None,
    device_index: Optional[int] = None,
    sample_rate: int = 16000,
    window_ms: int = 200,
    hop_ms: int = 100,
    chunk_size: int = 1600,
    cpu_only: bool = False,
    # Histogram params
    K: int = 72,
    tau: float = 0.8,
    smooth_k: int = 1,
    min_peak_height: float = 0.10,
    min_window_mass: float = 0.24,
    min_sep_deg: float = 20.0,
    min_active_ratio: float = 0.20,
    max_sources: int = 3,
    # Event gate params
    level_delta_on_db: float = 2.5,
    level_delta_off_db: float = 1.0,
    level_min_dbfs: float = -60.0,
    level_ema_alpha: float = 0.05,
    event_hold_ms: int = 300,
    min_R_clip: float = 0.18,
    event_refractory_ms: int = 120,
    # Onset params
    onset_alpha: float = 0.05,
):
    """Stream inference from microphone using ONNX model.

    Opens a 6-channel input stream (auto-detecting a ReSpeaker device when
    ``device_index`` is None), maintains a sliding ``window_ms`` buffer of 4
    selected channels, and every time the buffer fills runs ONNX inference,
    histogram DOA detection, onset scoring, and the level-change gate, printing
    one status line per hop and updating the polar visualizer. Runs until
    Ctrl+C; cleans up the stream and PyAudio on exit.
    """
    providers = ['CPUExecutionProvider'] if cpu_only else ['CUDAExecutionProvider', 'CPUExecutionProvider']
    infer = ONNXDOAStreaming(onnx_path, config_path, providers=providers)
    # Override K if provided — the model's bin count always wins.
    if K != infer.K:
        print(f"Warning: K mismatch. Model K={infer.K}, requested K={K}. Using model K.")
        K = infer.K
    det = HistDOADetector(
        K=K, tau=tau, gamma=1.5, smooth_k=smooth_k,
        window_bins=1, min_peak_height=min_peak_height, min_window_mass=min_window_mass,
        min_sep_deg=min_sep_deg, min_active_ratio=min_active_ratio, max_sources=max_sources,
        device="cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
    )
    gate = LevelChangeGate(
        delta_on_db=level_delta_on_db, delta_off_db=level_delta_off_db,
        level_min_dbfs=level_min_dbfs, ema_alpha=level_ema_alpha,
        min_R_clip=min_R_clip,
        hold_ms=event_hold_ms, refractory_ms=event_refractory_ms
    )
    onset = OnsetDetector(alpha=onset_alpha)
    visualizer = CurrentLineVisualizer()
    window_samples = int(sample_rate * window_ms / 1000)
    hop_samples = int(sample_rate * hop_ms / 1000)
    p = pyaudio.PyAudio()
    # Auto-detect the capture device by name when none was specified.
    if device_index is None:
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            name = info['name'].lower()
            # Check by name first (PulseAudio might hide channels)
            if 'respeaker' in name or 'seeed' in name or '2886' in name:
                device_index = i
                print(f"Auto-detected ReSpeaker at device {i}: {info['name']}")
                break
        if device_index is None:
            print("\n[Audio] Could not auto-detect Respeaker. Use --device-index or --list-devices.\n")
            p.terminate()
            return
    # Check device info
    device_info = p.get_device_info_by_index(device_index)
    print(f"Device info: {device_info['name']}")
    print(f" Max input channels: {device_info['maxInputChannels']}")
    print(f" Default sample rate: {device_info['defaultSampleRate']:.0f} Hz")
    # If device shows 0 channels, it's likely managed by PulseAudio
    # We'll still try to open it - sometimes it works despite the report
    if device_info['maxInputChannels'] == 0:
        print(" Warning: Device reports 0 channels (may be managed by PulseAudio)")
        print(" Attempting to open anyway...")
    CHANNELS = 6
    # Indices of the 4 microphone channels within the 6-channel raw capture —
    # presumably the ReSpeaker mic channels in the desired order; TODO confirm.
    RAW_CHANNELS = [1, 4, 3, 2]  # your requested order
    FORMAT = pyaudio.paInt16
    audio_buffer = np.zeros((4, window_samples), dtype=np.float32)
    buffer_fill = 0
    start_time = time.time()
    audio_queue = queue.Queue()
    stream_closed = False
    def _fill_buffer(in_data, frame_count, time_info, status_flags):
        # PyAudio callback thread: hand raw bytes to the main loop via the queue.
        if not stream_closed:
            audio_queue.put(in_data)
        return None, pyaudio.paContinue
    try:
        # Try to open the stream - PyAudio will validate channels
        stream = p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=sample_rate,
            input=True,
            input_device_index=device_index,
            frames_per_buffer=chunk_size,
            stream_callback=_fill_buffer
        )
        print(" Successfully opened audio stream with 6 channels")
    except Exception as e:
        print(f"\n[Audio] Could not open input device (index {device_index}).")
        print(f" Error: {e}")
        print("\n The ReSpeaker device is likely locked by PulseAudio.")
        print(" Solutions:")
        print(" 1. Temporarily stop PulseAudio: pulseaudio --kill")
        print(" 2. Then restart it after: pulseaudio --start")
        print(" 3. Or configure PulseAudio to allow direct ALSA access\n")
        p.terminate()
        return
    stream.start_stream()
    print(f"\n[Streaming] Started. Window: {window_ms}ms, Hop: {hop_ms}ms")
    print(" Press Ctrl+C to stop.\n")
    try:
        while True:
            try:
                data = audio_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            chunk_all = chunk_to_floatarray(data, CHANNELS)  # (6, N)
            audio_chunk = chunk_all[RAW_CHANNELS, :]  # (4, N)
            n = audio_chunk.shape[1]
            # Accumulate until the sliding window is full; only then run inference.
            if buffer_fill + n <= window_samples:
                audio_buffer[:, buffer_fill:buffer_fill + n] = audio_chunk
                buffer_fill += n
                continue
            remaining = window_samples - buffer_fill
            if remaining > 0:
                audio_buffer[:, buffer_fill:] = audio_chunk[:, :remaining]
            buffer_fill = window_samples
            # Inference
            t0 = time.perf_counter()
            feats = infer.compute_features(audio_buffer)
            logits = infer.inference_batch(feats)
            t_model = (time.perf_counter() - t0) * 1000.0
            T = logits.shape[0]
            energies = frame_rms_energy(audio_buffer, T)
            flux = spectral_flux_per_frame(audio_buffer, T)
            flux_recent = float(max(flux[-1], flux[-2] if T >= 2 else 0.0))
            flux_z = onset.update_flux(flux_recent)
            coh = OnsetDetector.last_segment_coherence(audio_buffer, T)
            # DOA detection (no VAD)
            t1 = time.perf_counter()
            det_result = det.detect(logits)
            t_hist = (time.perf_counter() - t1) * 1000.0
            peaks = det_result["peaks"]
            peaks_count = len(peaks)
            Rmax = det_result["R_clip"]
            level = rms_dbfs(audio_buffer)
            now = time.time() - start_time
            gate_open, diff_db = gate.update(level_dbfs=level, now_s=now,
                                             peaks_count=peaks_count, R_clip_max=Rmax)
            # Only show peaks while the gate is open; otherwise clear the plot.
            if gate_open:
                visualizer.update(peaks)
                gate_str = "OPEN "
            else:
                visualizer.update([])
                gate_str = "CLOSED"
            print(f"[{now:6.2f}s] LVL={level:6.1f} dBFS diff={diff_db:+4.1f} | "
                  f"FLUXz={flux_z:4.2f} COH={coh:4.2f} | "
                  f"GATE={gate_str} | "
                  f"MODEL={t_model:5.1f}ms HIST={t_hist:5.1f}ms | "
                  f"DOA(R={Rmax:.2f}, n={peaks_count})", end="")
            if peaks:
                az_str = ", ".join([f"{p['azimuth_deg']:.0f}°" for p in peaks[:3]])
                print(f" [{az_str}]")
            else:
                print()
            # Slide buffer left by one hop; then append any leftover chunk samples
            # (at most hop_samples of them) that didn't fit into this window.
            audio_buffer[:, :-hop_samples] = audio_buffer[:, hop_samples:]
            buffer_fill = window_samples - hop_samples
            if n > remaining:
                carry = min(n - remaining, hop_samples)
                if carry > 0:
                    audio_buffer[:, buffer_fill:buffer_fill + carry] = audio_chunk[:, remaining:remaining + carry]
                    buffer_fill += carry
    except KeyboardInterrupt:
        print("\n[Streaming] Stopped by user.")
    finally:
        # Signal the callback to stop queueing, then tear everything down.
        stream_closed = True
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
        p.terminate()
        plt.close('all')
def main():
    """CLI entry point: parse arguments and launch the microphone DOA streaming loop.

    With ``--list-devices``, prints available input devices and exits without
    requiring ``--onnx``; otherwise ``--onnx`` must point to an existing model.
    """
    parser = argparse.ArgumentParser(description="Stream ONNX DOA inference from microphone")
    parser.add_argument('--onnx', type=str, required=False, help='Path to ONNX model file')
    parser.add_argument('--config', type=str, default=None, help='Path to config.yaml')
    parser.add_argument('--device-index', type=int, default=None, help='Audio device index')
    parser.add_argument('--sample-rate', type=int, default=16000, help='Sample rate (Hz)')
    parser.add_argument('--window-ms', type=int, default=200, help='Window length (ms)')
    parser.add_argument('--hop-ms', type=int, default=100, help='Hop length (ms)')
    parser.add_argument('--chunk-size', type=int, default=1600, help='Audio chunk size')
    parser.add_argument('--cpu-only', action='store_true', help='Use CPU only')
    parser.add_argument('--list-devices', action='store_true', help='List available audio devices')
    # Histogram params
    parser.add_argument('--K', type=int, default=72, help='Number of azimuth bins')
    parser.add_argument('--tau', type=float, default=0.8, help='Softmax temperature')
    parser.add_argument('--smooth-k', type=int, default=1, help='Smoothing kernel size')
    parser.add_argument('--min-peak-height', type=float, default=0.10, help='Min peak height')
    parser.add_argument('--min-window-mass', type=float, default=0.24, help='Min window mass')
    parser.add_argument('--min-sep-deg', type=float, default=20.0, help='Min separation (deg)')
    parser.add_argument('--min-active-ratio', type=float, default=0.20, help='Min active ratio')
    parser.add_argument('--max-sources', type=int, default=3, help='Max sources')
    # Event gate params
    parser.add_argument('--level-delta-on-db', type=float, default=2.5, help='Level delta on (dB)')
    parser.add_argument('--level-delta-off-db', type=float, default=1.0, help='Level delta off (dB)')
    parser.add_argument('--level-min-dbfs', type=float, default=-60.0, help='Min level (dBFS)')
    parser.add_argument('--level-ema-alpha', type=float, default=0.05, help='Level EMA alpha')
    parser.add_argument('--event-hold-ms', type=int, default=300, help='Event hold (ms)')
    parser.add_argument('--min-R-clip', type=float, default=0.18, help='Min R clip')
    parser.add_argument('--event-refractory-ms', type=int, default=120, help='Event refractory (ms)')
    # Onset params
    parser.add_argument('--onset-alpha', type=float, default=0.05, help='Onset EMA alpha')
    args = parser.parse_args()
    if args.list_devices:
        p = pyaudio.PyAudio()
        print("\nAvailable audio input devices:")
        print("-" * 80)
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if info['maxInputChannels'] > 0:
                print(f"Device {i}: {info['name']}")
                print(f" Channels: {info['maxInputChannels']}, Sample Rate: {info['defaultSampleRate']:.0f} Hz\n")
        p.terminate()
        return
    if args.onnx is None:
        parser.error("--onnx is required (unless using --list-devices)")
    onnx_path = Path(args.onnx)
    if not onnx_path.exists():
        parser.error(f"ONNX model not found: {onnx_path}")
    stream_onnx_inference(
        onnx_path=str(onnx_path),
        config_path=args.config,
        device_index=args.device_index,
        sample_rate=args.sample_rate,
        window_ms=args.window_ms,
        hop_ms=args.hop_ms,
        # Bug fix: honor the --chunk-size flag. Previously chunk_size was
        # hard-coded to 1600 with `args.chunk_size` commented out, so the CLI
        # option was silently ignored.
        chunk_size=args.chunk_size,
        cpu_only=args.cpu_only,
        K=args.K,
        tau=args.tau,
        smooth_k=args.smooth_k,
        min_peak_height=args.min_peak_height,
        min_window_mass=args.min_window_mass,
        min_sep_deg=args.min_sep_deg,
        min_active_ratio=args.min_active_ratio,
        max_sources=args.max_sources,
        level_delta_on_db=args.level_delta_on_db,
        level_delta_off_db=args.level_delta_off_db,
        level_min_dbfs=args.level_min_dbfs,
        level_ema_alpha=args.level_ema_alpha,
        event_hold_ms=args.event_hold_ms,
        min_R_clip=args.min_R_clip,
        event_refractory_ms=args.event_refractory_ms,
        onset_alpha=args.onset_alpha,
    )


if __name__ == "__main__":
    main()