#!/usr/bin/env python3
"""
Real-time DOA inference using ONNX model with microphone streaming.
Includes histogram-based detection, event gates, and onset detection.
"""
import sys
from pathlib import Path

# Add src directory to Python path
project_root = Path(__file__).parent
src_dir = project_root / "src"
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

import math
import numpy as np
import time
import queue
import argparse
import pyaudio
import onnxruntime as ort
import yaml
from typing import Optional, Dict, List, Tuple
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import torch
import torch.nn.functional as F

from mirokai_doa.features import stft_multi, compute_mag_phase_cos_sin


# -------------------------
# Math helpers (numpy version)
# -------------------------
def _angles_deg_np(K: int):
    """Return (bin-center angles in degrees, their cos, their sin, bin width) for K azimuth bins."""
    bin_size = 360.0 / K
    deg = (np.arange(K, dtype=np.float32) + 0.5) * bin_size
    rad = deg * np.pi / 180.0
    return deg, np.cos(rad), np.sin(rad), bin_size


def _softmax_temp_np(logits: np.ndarray, tau: float = 0.8) -> np.ndarray:
    """Numerically-stable temperature softmax along the last axis."""
    exp_logits = np.exp((logits - np.max(logits, axis=-1, keepdims=True)) / tau)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)


def _circular_window_sum_np(row: np.ndarray, idx: int, half_w: int) -> float:
    """Sum of `row` over the circular (wrap-around) window [idx-half_w, idx+half_w]."""
    K = row.size
    if half_w <= 0:
        return float(row[idx])
    # Vectorized circular gather instead of a Python loop over offsets.
    window = np.arange(idx - half_w, idx + half_w + 1) % K
    return float(row[window].sum())


def _parabolic_peak_refine_np(row: np.ndarray, k: int) -> float:
    """Sub-bin peak offset in [-0.5, 0.5] via 3-point parabolic interpolation around bin k."""
    K = row.size
    km1, kp1 = (k - 1) % K, (k + 1) % K
    y1, y2, y3 = float(row[km1]), float(row[k]), float(row[kp1])
    denom = (y1 - 2 * y2 + y3)
    if abs(denom) < 1e-9:
        # Flat neighborhood: no refinement possible.
        return 0.0
    delta = 0.5 * (y1 - y3) / denom
    return float(max(min(delta, 0.5), -0.5))


def _min_circ_separation_bins(a: int, chosen: List[int], K: int) -> int:
    """Minimum circular bin distance from `a` to any index in `chosen` (K when empty)."""
    if not chosen:
        return K
    dmin = K
    for j in chosen:
        d = abs(a - j)
        d = min(d, K - d)
        dmin = min(dmin, d)
    return dmin


# -------------------------
# Audio helpers
# -------------------------
def byte_to_float(data: bytes) -> np.ndarray:
    """Convert raw little-endian 16-bit PCM bytes to float32 samples in [-1.0, 1.0)."""
    samples = np.frombuffer(data, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0


def chunk_to_floatarray(data: bytes, channels: int) -> np.ndarray:
    """De-interleave PCM bytes into a (channels, n_samples) float32 array."""
    float_data = byte_to_float(data)
    return float_data.reshape(-1, channels).T


def rms_dbfs(x: np.ndarray, eps: float = 1e-9) -> float:
    """RMS level of x in dBFS; eps floors the RMS so log10 is defined for silence."""
    val = np.sqrt((x * x).mean())
    return 20.0 * np.log10(max(val, eps))


def frame_rms_energy(audio_buffer: np.ndarray, T: int) -> np.ndarray:
    """Split audio_buffer (C,N) into T equal segments; return per-frame RMS (normalized)."""
    C, N = audio_buffer.shape
    if T <= 0:
        return np.ones(1, dtype=np.float32)
    edges = np.linspace(0, N, T + 1, dtype=int)
    e = []
    for i in range(T):
        seg = audio_buffer[:, edges[i]:edges[i + 1]]
        if seg.size == 0:
            e.append(0.0)
        else:
            rms = np.sqrt((seg * seg).mean())
            e.append(rms)
    e = np.asarray(e, dtype=np.float32)
    # Normalize so the mean frame energy is ~1 (guard against all-zero input).
    e = e / max(e.mean(), 1e-6)
    return e


def spectral_flux_per_frame(audio_buffer: np.ndarray, T: int) -> np.ndarray:
    """Compute per-frame spectral flux across T segments from mono mix."""
    C, N = audio_buffer.shape
    if T <= 1:
        return np.zeros((T,), dtype=np.float32)
    mono = audio_buffer.mean(axis=0)
    edges = np.linspace(0, N, T + 1, dtype=int)
    mags = []
    for i in range(T):
        seg = mono[edges[i]:edges[i + 1]]
        if seg.size == 0:
            mags.append(np.zeros(1, dtype=np.float32))
            continue
        # Hann window for reasonable-length segments; rectangular for tiny ones.
        win = np.hanning(len(seg)) if len(seg) > 8 else np.ones_like(seg)
        S = np.fft.rfft(seg * win, n=len(seg))
        mags.append(np.abs(S).astype(np.float32))
    flux = np.zeros(T, dtype=np.float32)
    for t in range(1, T):
        a = mags[t - 1]
        b = mags[t]
        L = min(len(a), len(b))
        if L == 0:
            flux[t] = 0.0
            continue
        # Half-wave-rectified magnitude increase, normalized by current frame energy.
        diff = b[:L] - a[:L]
        pos = np.maximum(diff, 0.0)
        denom = np.sum(b[:L]) + 1e-6
        flux[t] = float(np.sum(pos) / denom)
    return flux


# -------------------------
# Onset detector
# -------------------------
class OnsetDetector:
    """Tracks an EMA mean/variance of spectral flux and scores onsets as z-scores."""

    def __init__(self, alpha: float = 0.05):
        self.alpha = float(alpha)  # EMA update rate
        self.mu = 0.0              # running flux mean
        self.var = 1.0             # running flux variance
        self.inited = False

    def update_flux(self, flux_recent: float) -> float:
        """Fold flux_recent into the running statistics; return its z-score."""
        if not self.inited:
            # Seed the statistics on the first observation (z is 0 for this call).
            self.mu = flux_recent
            self.var = 1e-3 + abs(flux_recent)
            self.inited = True
        delta = flux_recent - self.mu
        self.mu += self.alpha * delta
        self.var = (1 - self.alpha) * self.var + self.alpha * delta * delta
        sigma = max(np.sqrt(self.var), 1e-6)
        z = (flux_recent - self.mu) / sigma
        return float(z)

    @staticmethod
    def last_segment_coherence(audio_buffer: np.ndarray, T: int,
                               pairs: Tuple[Tuple[int, int], ...] = ((0, 1), (0, 2), (0, 3))) -> float:
        """Max |Pearson correlation| between channel pairs over the last of T segments.

        The default for `pairs` is an immutable tuple (fixes the original mutable
        default-argument pitfall); callers may still pass any iterable of pairs.
        """
        C, N = audio_buffer.shape
        if T < 1:
            return 0.0
        edges = np.linspace(0, N, T + 1, dtype=int)
        s0, s1 = int(edges[-2]), int(edges[-1])
        seg = audio_buffer[:, s0:s1]
        if seg.shape[1] < 16:
            # Too few samples for a meaningful correlation estimate.
            return 0.0
        rmax = 0.0
        for (i, j) in pairs:
            xi = seg[i] - seg[i].mean()
            xj = seg[j] - seg[j].mean()
            denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-9)
            r = float(np.dot(xi, xj) / denom)
            rmax = max(rmax, abs(r))
        return rmax


# -------------------------
# Histogram DOA detector (numpy/torch hybrid)
# -------------------------
class HistDOADetector:
    """Aggregates per-frame azimuth logits into a circular histogram and picks peaks."""

    def __init__(
        self,
        K: int = 72,
        tau: float = 0.8,
        gamma: float = 1.5,
        smooth_k: int = 1,
        window_bins: int = 1,
        min_peak_height: float = 0.10,
        min_window_mass: float = 0.24,
        min_sep_deg: float = 20.0,
        min_active_ratio: float = 0.20,
        max_sources: int = 3,
        device: str = "cpu",
    ):
        self.K = int(K)                                    # number of azimuth bins
        self.tau = float(tau)                              # softmax temperature
        self.gamma = float(gamma)                          # frame-confidence exponent
        self.smooth_k = int(smooth_k)                      # circular smoothing half-width
        self.window_bins = int(window_bins)                # half-width for peak mass check
        self.min_peak_height = float(min_peak_height)
        self.min_window_mass = float(min_window_mass)
        self.min_sep_deg = float(min_sep_deg)
        self.min_active_ratio = float(min_active_ratio)
        self.max_sources = int(max_sources)
        self.device = torch.device(device)
        self._deg, self._cos, self._sin, self._bin_size = self._angles_deg(self.K)

    def _angles_deg(self, K: int):
        """Bin-center angles (deg) plus cos/sin lookup tables and the bin width."""
        bin_size = 360.0 / K
        deg = torch.arange(K, device=self.device, dtype=torch.float32) + 0.5
        deg = deg * bin_size
        rad = deg * math.pi / 180.0
        return deg, torch.cos(rad), torch.sin(rad), bin_size

    def _aggregate_histogram(self, logits: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Aggregate histogram from logits and VAD mask.

        Returns (histogram [K], active frame ratio, clipped resultant length R).
        """
        logits_t = torch.from_numpy(logits).float().to(self.device)
        mask_t = torch.from_numpy(mask).float().to(self.device)
        probs = F.softmax(logits_t / self.tau, dim=-1)  # [T,K]
        # Per-frame resultant length R_t in [0,1] measures angular concentration.
        x = torch.matmul(probs, self._cos)
        y = torch.matmul(probs, self._sin)
        R_t = torch.clamp(torch.sqrt(x * x + y * y), 0, 1)
        # Weighted histogram: mask gates frames, R_t^gamma emphasizes confident ones.
        w = mask_t * (R_t ** self.gamma)
        if w.sum() <= 0:
            # Avoid a zero-weight sum; fall back to a tiny uniform weighting.
            w = torch.ones_like(w) * 1e-6
        hist = torch.matmul(w, probs)
        hist = hist / hist.sum().clamp_min(1e-8)
        if self.smooth_k > 0:
            # Circular box smoothing via wrap-padding + conv1d (preserves total mass).
            s = self.smooth_k
            pad = torch.cat([hist[-s:], hist, hist[:s]], dim=0).view(1, 1, -1)
            kernel = torch.ones(1, 1, 2 * s + 1, device=self.device) / (2 * s + 1)
            hist = F.conv1d(pad, kernel, padding=0).view(-1)
        X = torch.dot(hist, self._cos)
        Y = torch.dot(hist, self._sin)
        R_clip = float(torch.sqrt(X * X + Y * Y).item())
        active_ratio = float(mask_t.mean().item())
        return hist.detach().cpu().numpy(), active_ratio, R_clip

    def _pick_peaks(self, hist: np.ndarray) -> List[Dict[str, float]]:
        """Pick up to max_sources peaks passing separation/height/mass gates."""
        hist_t = torch.from_numpy(hist).float()
        K = self.K
        bin_size = self._bin_size
        # Strict local maxima on the circle (each bin vs. its two neighbors).
        left = torch.roll(hist_t, 1, 0)
        right = torch.roll(hist_t, -1, 0)
        cand_idxs = ((hist_t > left) & (hist_t > right)).nonzero(as_tuple=False).flatten().tolist()
        cand_idxs.sort(key=lambda i: float(hist_t[i].item()), reverse=True)
        chosen, out = [], []
        min_sep_bins = max(1, int(round(self.min_sep_deg / bin_size)))
        for idx in cand_idxs:
            if _min_circ_separation_bins(idx, chosen, K) < min_sep_bins:
                continue
            if float(hist_t[idx].item()) < self.min_peak_height:
                continue
            mass = _circular_window_sum_np(hist, idx, self.window_bins)
            if mass < self.min_window_mass:
                continue
            # Sub-bin parabolic refinement for a smoother azimuth estimate.
            delta = _parabolic_peak_refine_np(hist, idx)
            angle_deg = ((idx + 0.5 + delta) * bin_size) % 360.0
            out.append({"azimuth_deg": angle_deg, "score": float(mass)})
            chosen.append(idx)
            if len(out) >= self.max_sources:
                break
        return out

    def detect(self, logits: np.ndarray) -> Dict[str, object]:
        """Detect DOA from logits (no VAD separation)."""
        # Use all frames (no VAD masking)
        mask = np.ones(logits.shape[0], dtype=np.float32)
        hist, active_ratio, R_clip = self._aggregate_histogram(logits, mask)
        peaks = self._pick_peaks(hist) if active_ratio >= self.min_active_ratio else []
        bins_deg = (np.arange(self.K) + 0.5) * (360.0 / self.K)
        return {
            "peaks": peaks,
            "active_ratio": active_ratio,
            "R_clip": R_clip,
            "hist": hist,
            "bins_deg": bins_deg,
            "has_event": bool(peaks),
        }


# -------------------------
# Event gate
# -------------------------
class LevelChangeGate:
    """Hysteresis gate: opens on a level rise above the tracked background (or on
    confident DOA peaks), and closes after a hold period once both drop back."""

    def __init__(
        self,
        delta_on_db: float = 2.5,
        delta_off_db: float = 1.0,
        level_min_dbfs: float = -60.0,
        ema_alpha: float = 0.05,
        min_R_clip: float = 0.18,
        hold_ms: int = 300,
        refractory_ms: int = 120
    ):
        self.delta_on_db = float(delta_on_db)        # rise over background needed to open
        self.delta_off_db = float(delta_off_db)      # fall-back threshold to close
        self.level_min_dbfs = float(level_min_dbfs)  # absolute floor for level-based open
        self.ema_alpha = float(ema_alpha)            # background tracker EMA rate
        self.min_R_clip = float(min_R_clip)          # DOA-confidence alternative trigger
        self.hold_s = float(hold_ms) / 1000.0        # minimum open duration
        self.refractory_s = float(refractory_ms) / 1000.0  # minimum time between opens
        self.bg_dbfs = None                          # background level estimate (dBFS)
        self.active = False
        self.last_change_time = 0.0

    def update(self, level_dbfs: float, now_s: float, peaks_count: int, R_clip_max: float):
        """Advance the gate; returns (is_open, level_above_background_db)."""
        if self.bg_dbfs is None:
            self.bg_dbfs = level_dbfs
        diff_db = level_dbfs - self.bg_dbfs
        want_open = (
            (now_s - self.last_change_time) >= self.refractory_s
            and ((level_dbfs > self.level_min_dbfs and diff_db >= self.delta_on_db)
                 or (peaks_count > 0 and R_clip_max >= self.min_R_clip))
        )
        if not self.active:
            if want_open:
                self.active = True
                self.last_change_time = now_s
        else:
            if (now_s - self.last_change_time) >= self.hold_s:
                want_close = (
                    (diff_db <= self.delta_off_db)
                    and (peaks_count == 0 or R_clip_max < self.min_R_clip)
                )
                if want_close:
                    self.active = False
                    self.last_change_time = now_s
        # Track background level regardless of gate state.
        self.bg_dbfs = (1.0 - self.ema_alpha) * self.bg_dbfs + self.ema_alpha * level_dbfs
        return self.active, diff_db


# -------------------------
# ONNX Inference
# -------------------------
class ONNXDOAStreaming:
    """Loads the ONNX DOA model plus feature config and runs batched inference."""

    def __init__(
        self,
        onnx_path: str,
        config_path: Optional[str] = None,
        providers: Optional[list] = None
    ):
        if config_path is None:
            config_path = project_root / "configs" / "train.yaml"
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
        self.features_cfg = self.config.get('features', {})
        self.sr = self.features_cfg.get('sr', 16000)
        self.win_s = self.features_cfg.get('win_s', 0.032)
        self.hop_s = self.features_cfg.get('hop_s', 0.010)
        self.nfft = self.features_cfg.get('nfft', 1024)
        self.K = self.features_cfg.get('K', 72)
        if providers is None:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        input_shape = self.session.get_inputs()[0].shape
        # Heuristic: the DoAEstimator variant takes 4-D input with 513 freq bins.
        self.is_doa_model = input_shape[-1] == 513 if len(input_shape) == 4 else False
        print(f"ONNX Model loaded: {onnx_path}")
        print(f" Input shape: {input_shape}")
        print(f" Model type: {'DoAEstimator' if self.is_doa_model else 'TFPoolClassifierNoCond'}")
        print(f" Providers: {self.session.get_providers()}")

    def compute_features(self, mixture: np.ndarray) -> np.ndarray:
        """Compute STFT mag/phase features for a 4-channel mixture."""
        if mixture.ndim == 1:
            raise ValueError("Mixture must be multichannel (4 channels)")
        # Accept (N, 4) input and transpose to (4, N).
        if mixture.shape[0] != 4 and mixture.shape[1] == 4:
            mixture = mixture.T
        if mixture.shape[0] != 4:
            raise ValueError(f"Expected 4 channels, got {mixture.shape[0]}")
        x4 = mixture.astype(np.float32)
        X, freqs, times = stft_multi(x4.T, fs=self.sr, win_s=self.win_s, hop_s=self.hop_s,
                                     nfft=self.nfft, window="hann", center=True, pad_mode="reflect")
        feats = compute_mag_phase_cos_sin(X, dtype=np.float32)
        return feats

    def inference_batch(self, feats: np.ndarray, batch_size: int = 25) -> np.ndarray:
        """Run the model over feats [T, 12, F] in fixed-size batches; returns logits [T, K]."""
        # Local renamed from `F` to `n_freq` so it no longer shadows
        # torch.nn.functional (imported as F at module level).
        T_frames, C_feat, n_freq = feats.shape
        assert C_feat == 12, f"Expected 12 feature channels, got {C_feat}"
        all_logits = []
        for start_idx in range(0, T_frames, batch_size):
            end_idx = min(start_idx + batch_size, T_frames)
            batch_feats = feats[start_idx:end_idx]
            batch_T = batch_feats.shape[0]
            if batch_T < batch_size:
                # Zero-pad the final batch up to the fixed batch length.
                padding = np.zeros((batch_size - batch_T, C_feat, n_freq), dtype=batch_feats.dtype)
                batch_feats = np.concatenate([batch_feats, padding], axis=0)
            # Model expects [1, C, T, F].
            feats_tensor = batch_feats.transpose(1, 0, 2)[np.newaxis, ...]
            outputs = self.session.run([self.output_name], {self.input_name: feats_tensor.astype(np.float32)})
            batch_logits = outputs[0]
            # Normalize the possible output shapes down to [batch_T, K].
            if batch_logits.ndim == 2:
                if batch_logits.shape[0] == 1 and batch_logits.shape[1] == self.K:
                    # Single pooled prediction: repeat it for every frame in the batch.
                    batch_logits = np.tile(batch_logits, (batch_T, 1))
                elif batch_logits.shape[0] == 1:
                    batch_logits = batch_logits[0]
                else:
                    batch_logits = batch_logits[:batch_T]
            elif batch_logits.ndim == 3:
                batch_logits = batch_logits[0, :batch_T]
            all_logits.append(batch_logits)
        return np.concatenate(all_logits, axis=0)


# -------------------------
# Visualization
# -------------------------
class CurrentLineVisualizer:
    """Interactive polar matplotlib view: one radial line per detected source."""

    def __init__(self, title: str = "Current DOA"):
        self.fig = plt.figure(figsize=(7.5, 7.5))
        self.ax = self.fig.add_subplot(111, projection='polar')
        self._setup_axes(title)
        plt.ion()
        plt.show(block=False)

    def _setup_axes(self, title: str):
        """Reset the polar axes (0° at North, clockwise) with an outer reference ring."""
        self.ax.clear()
        self.ax.set_title(title, fontsize=13, fontweight='bold', pad=16)
        self.ax.set_theta_zero_location('N')
        self.ax.set_theta_direction(-1)
        self.ax.set_thetalim(0, 2 * np.pi)
        self.ax.set_ylim(0, 1.05)
        self.ax.set_yticklabels([])
        self.ax.add_patch(Circle((0, 0), 1.0, fill=False, color='gray',
                                 linestyle='--', linewidth=1, alpha=0.5))
        self.ax.grid(alpha=0.2)

    def update(self, peaks: List[Dict]):
        """Redraw with up to three peaks; line width scales with peak score."""
        self._setup_axes("Current DOA")
        for pk in peaks[:3]:
            az = float(pk["azimuth_deg"])
            sc = float(pk.get("score", 0.2))
            lw = 2.0 + 5.0 * float(np.clip(sc, 0.0, 0.6))
            theta = np.deg2rad(az)
            self.ax.plot([theta, theta], [0.0, 1.0], color='tab:green',
                         linewidth=lw, solid_capstyle='round')
            self.ax.text(theta, 1.02, f"{az:.0f}°", ha='center', va='bottom',
                         fontsize=10, color='tab:green', fontweight='bold')
        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()


# -------------------------
# Main streaming function
# -------------------------
def stream_onnx_inference(
    onnx_path: str,
    config_path: Optional[str] = None,
    device_index: Optional[int] = None,
    sample_rate: int = 16000,
    window_ms: int = 200,
    hop_ms: int = 100,
    chunk_size: int = 1600,
    cpu_only: bool = False,
    # Histogram params
    K: int = 72,
    tau: float = 0.8,
    smooth_k: int = 1,
    min_peak_height: float = 0.10,
    min_window_mass: float = 0.24,
    min_sep_deg: float = 20.0,
    min_active_ratio: float = 0.20,
    max_sources: int = 3,
    # Event gate params
    level_delta_on_db: float = 2.5,
    level_delta_off_db: float = 1.0,
    level_min_dbfs: float = -60.0,
    level_ema_alpha: float = 0.05,
    event_hold_ms: int = 300,
    min_R_clip: float = 0.18,
    event_refractory_ms: int = 120,
    # Onset params
    onset_alpha: float = 0.05,
):
    """Stream inference from microphone using ONNX model."""
    providers = ['CPUExecutionProvider'] if cpu_only else ['CUDAExecutionProvider', 'CPUExecutionProvider']
    infer = ONNXDOAStreaming(onnx_path, config_path, providers=providers)

    # The model defines the number of azimuth bins; override a mismatched request.
    if K != infer.K:
        print(f"Warning: K mismatch. Model K={infer.K}, requested K={K}. Using model K.")
        K = infer.K

    det = HistDOADetector(
        K=K, tau=tau, gamma=1.5, smooth_k=smooth_k, window_bins=1,
        min_peak_height=min_peak_height, min_window_mass=min_window_mass,
        min_sep_deg=min_sep_deg, min_active_ratio=min_active_ratio,
        max_sources=max_sources,
        device="cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
    )
    gate = LevelChangeGate(
        delta_on_db=level_delta_on_db, delta_off_db=level_delta_off_db,
        level_min_dbfs=level_min_dbfs, ema_alpha=level_ema_alpha,
        min_R_clip=min_R_clip, hold_ms=event_hold_ms,
        refractory_ms=event_refractory_ms
    )
    onset = OnsetDetector(alpha=onset_alpha)
    visualizer = CurrentLineVisualizer()

    window_samples = int(sample_rate * window_ms / 1000)
    hop_samples = int(sample_rate * hop_ms / 1000)

    p = pyaudio.PyAudio()
    if device_index is None:
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            name = info['name'].lower()
            # Check by name first (PulseAudio might hide channels)
            if 'respeaker' in name or 'seeed' in name or '2886' in name:
                device_index = i
                print(f"Auto-detected ReSpeaker at device {i}: {info['name']}")
                break
    if device_index is None:
        print("\n[Audio] Could not auto-detect Respeaker. Use --device-index or --list-devices.\n")
        p.terminate()
        return

    # Check device info
    device_info = p.get_device_info_by_index(device_index)
    print(f"Device info: {device_info['name']}")
    print(f" Max input channels: {device_info['maxInputChannels']}")
    print(f" Default sample rate: {device_info['defaultSampleRate']:.0f} Hz")

    # If device shows 0 channels, it's likely managed by PulseAudio
    # We'll still try to open it - sometimes it works despite the report
    if device_info['maxInputChannels'] == 0:
        print(" Warning: Device reports 0 channels (may be managed by PulseAudio)")
        print(" Attempting to open anyway...")

    CHANNELS = 6
    RAW_CHANNELS = [1, 4, 3, 2]  # your requested order
    FORMAT = pyaudio.paInt16

    audio_buffer = np.zeros((4, window_samples), dtype=np.float32)
    buffer_fill = 0
    start_time = time.time()
    audio_queue = queue.Queue()
    stream_closed = False

    def _fill_buffer(in_data, frame_count, time_info, status_flags):
        # PyAudio callback: hand raw bytes off to the processing thread.
        if not stream_closed:
            audio_queue.put(in_data)
        return None, pyaudio.paContinue

    try:
        # Try to open the stream - PyAudio will validate channels
        stream = p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=sample_rate,
            input=True,
            input_device_index=device_index,
            frames_per_buffer=chunk_size,
            stream_callback=_fill_buffer
        )
        print(" Successfully opened audio stream with 6 channels")
    except Exception as e:
        print(f"\n[Audio] Could not open input device (index {device_index}).")
        print(f" Error: {e}")
        print("\n The ReSpeaker device is likely locked by PulseAudio.")
        print(" Solutions:")
        print(" 1. Temporarily stop PulseAudio: pulseaudio --kill")
        print(" 2. Then restart it after: pulseaudio --start")
        print(" 3. Or configure PulseAudio to allow direct ALSA access\n")
        p.terminate()
        return

    stream.start_stream()
    print(f"\n[Streaming] Started. Window: {window_ms}ms, Hop: {hop_ms}ms")
    print(" Press Ctrl+C to stop.\n")

    try:
        while True:
            try:
                data = audio_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            chunk_all = chunk_to_floatarray(data, CHANNELS)  # (6, N)
            audio_chunk = chunk_all[RAW_CHANNELS, :]  # (4, N)
            n = audio_chunk.shape[1]

            # Accumulate until the analysis window is full.
            if buffer_fill + n <= window_samples:
                audio_buffer[:, buffer_fill:buffer_fill + n] = audio_chunk
                buffer_fill += n
                continue
            remaining = window_samples - buffer_fill
            if remaining > 0:
                audio_buffer[:, buffer_fill:] = audio_chunk[:, :remaining]
                buffer_fill = window_samples

            # Inference
            t0 = time.perf_counter()
            feats = infer.compute_features(audio_buffer)
            logits = infer.inference_batch(feats)
            t_model = (time.perf_counter() - t0) * 1000.0

            T = logits.shape[0]
            energies = frame_rms_energy(audio_buffer, T)
            flux = spectral_flux_per_frame(audio_buffer, T)
            flux_recent = float(max(flux[-1], flux[-2] if T >= 2 else 0.0))
            flux_z = onset.update_flux(flux_recent)
            coh = OnsetDetector.last_segment_coherence(audio_buffer, T)

            # DOA detection (no VAD)
            t1 = time.perf_counter()
            det_result = det.detect(logits)
            t_hist = (time.perf_counter() - t1) * 1000.0
            peaks = det_result["peaks"]
            peaks_count = len(peaks)
            Rmax = det_result["R_clip"]

            level = rms_dbfs(audio_buffer)
            now = time.time() - start_time
            gate_open, diff_db = gate.update(level_dbfs=level, now_s=now,
                                             peaks_count=peaks_count, R_clip_max=Rmax)

            if gate_open:
                visualizer.update(peaks)
                gate_str = "OPEN "
            else:
                visualizer.update([])
                gate_str = "CLOSED"

            print(f"[{now:6.2f}s] LVL={level:6.1f} dBFS diff={diff_db:+4.1f} | "
                  f"FLUXz={flux_z:4.2f} COH={coh:4.2f} | "
                  f"GATE={gate_str} | "
                  f"MODEL={t_model:5.1f}ms HIST={t_hist:5.1f}ms | "
                  f"DOA(R={Rmax:.2f}, n={peaks_count})", end="")
            if peaks:
                az_str = ", ".join([f"{p['azimuth_deg']:.0f}°" for p in peaks[:3]])
                print(f" [{az_str}]")
            else:
                print()

            # Slide buffer
            audio_buffer[:, :-hop_samples] = audio_buffer[:, hop_samples:]
            buffer_fill = window_samples - hop_samples
            if n > remaining:
                # Carry over the leftover samples of the chunk (capped at one hop).
                carry = min(n - remaining, hop_samples)
                if carry > 0:
                    audio_buffer[:, buffer_fill:buffer_fill + carry] = audio_chunk[:, remaining:remaining + carry]
                    buffer_fill += carry
    except KeyboardInterrupt:
        print("\n[Streaming] Stopped by user.")
    finally:
        stream_closed = True
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
        p.terminate()
        plt.close('all')


def main():
    """CLI entry point: parse args, optionally list devices, then start streaming."""
    parser = argparse.ArgumentParser(description="Stream ONNX DOA inference from microphone")
    parser.add_argument('--onnx', type=str, required=False, help='Path to ONNX model file')
    parser.add_argument('--config', type=str, default=None, help='Path to config.yaml')
    parser.add_argument('--device-index', type=int, default=None, help='Audio device index')
    parser.add_argument('--sample-rate', type=int, default=16000, help='Sample rate (Hz)')
    parser.add_argument('--window-ms', type=int, default=200, help='Window length (ms)')
    parser.add_argument('--hop-ms', type=int, default=100, help='Hop length (ms)')
    parser.add_argument('--chunk-size', type=int, default=1600, help='Audio chunk size')
    parser.add_argument('--cpu-only', action='store_true', help='Use CPU only')
    parser.add_argument('--list-devices', action='store_true', help='List available audio devices')
    # Histogram params
    parser.add_argument('--K', type=int, default=72, help='Number of azimuth bins')
    parser.add_argument('--tau', type=float, default=0.8, help='Softmax temperature')
    parser.add_argument('--smooth-k', type=int, default=1, help='Smoothing kernel size')
    parser.add_argument('--min-peak-height', type=float, default=0.10, help='Min peak height')
    parser.add_argument('--min-window-mass', type=float, default=0.24, help='Min window mass')
    parser.add_argument('--min-sep-deg', type=float, default=20.0, help='Min separation (deg)')
    parser.add_argument('--min-active-ratio', type=float, default=0.20, help='Min active ratio')
    parser.add_argument('--max-sources', type=int, default=3, help='Max sources')
    # Event gate params
    parser.add_argument('--level-delta-on-db', type=float, default=2.5, help='Level delta on (dB)')
    parser.add_argument('--level-delta-off-db', type=float, default=1.0, help='Level delta off (dB)')
    parser.add_argument('--level-min-dbfs', type=float, default=-60.0, help='Min level (dBFS)')
    parser.add_argument('--level-ema-alpha', type=float, default=0.05, help='Level EMA alpha')
    parser.add_argument('--event-hold-ms', type=int, default=300, help='Event hold (ms)')
    parser.add_argument('--min-R-clip', type=float, default=0.18, help='Min R clip')
    parser.add_argument('--event-refractory-ms', type=int, default=120, help='Event refractory (ms)')
    # Onset params
    parser.add_argument('--onset-alpha', type=float, default=0.05, help='Onset EMA alpha')
    args = parser.parse_args()

    if args.list_devices:
        p = pyaudio.PyAudio()
        print("\nAvailable audio input devices:")
        print("-" * 80)
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if info['maxInputChannels'] > 0:
                print(f"Device {i}: {info['name']}")
                print(f" Channels: {info['maxInputChannels']}, Sample Rate: {info['defaultSampleRate']:.0f} Hz\n")
        p.terminate()
        return

    if args.onnx is None:
        parser.error("--onnx is required (unless using --list-devices)")
    onnx_path = Path(args.onnx)
    if not onnx_path.exists():
        parser.error(f"ONNX model not found: {onnx_path}")

    stream_onnx_inference(
        onnx_path=str(onnx_path),
        config_path=args.config,
        device_index=args.device_index,
        sample_rate=args.sample_rate,
        window_ms=args.window_ms,
        hop_ms=args.hop_ms,
        chunk_size=args.chunk_size,  # fix: was hard-coded to 1600, ignoring --chunk-size
        cpu_only=args.cpu_only,
        K=args.K,
        tau=args.tau,
        smooth_k=args.smooth_k,
        min_peak_height=args.min_peak_height,
        min_window_mass=args.min_window_mass,
        min_sep_deg=args.min_sep_deg,
        min_active_ratio=args.min_active_ratio,
        max_sources=args.max_sources,
        level_delta_on_db=args.level_delta_on_db,
        level_delta_off_db=args.level_delta_off_db,
        level_min_dbfs=args.level_min_dbfs,
        level_ema_alpha=args.level_ema_alpha,
        event_hold_ms=args.event_hold_ms,
        min_R_clip=args.min_R_clip,
        event_refractory_ms=args.event_refractory_ms,
        onset_alpha=args.onset_alpha,
    )


if __name__ == "__main__":
    main()