|
|
|
|
|
""" |
|
|
Real-time DOA inference using ONNX model with microphone streaming. |
|
|
Includes histogram-based detection, event gates, and onset detection. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Make the repository's "src" directory importable so `mirokai_doa` resolves
# when this script is run directly (rather than installed as a package).
project_root = Path(__file__).parent

src_dir = project_root / "src"

if str(src_dir) not in sys.path:

    sys.path.insert(0, str(src_dir))
|
|
|
|
|
import math |
|
|
import numpy as np |
|
|
import time |
|
|
import queue |
|
|
import argparse |
|
|
import pyaudio |
|
|
import onnxruntime as ort |
|
|
import yaml |
|
|
from typing import Optional, Dict, List, Tuple |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.patches import Circle |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from mirokai_doa.features import stft_multi, compute_mag_phase_cos_sin |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _angles_deg_np(K: int): |
|
|
bin_size = 360.0 / K |
|
|
deg = (np.arange(K, dtype=np.float32) + 0.5) * bin_size |
|
|
rad = deg * np.pi / 180.0 |
|
|
return deg, np.cos(rad), np.sin(rad), bin_size |
|
|
|
|
|
def _softmax_temp_np(logits: np.ndarray, tau: float = 0.8) -> np.ndarray: |
|
|
exp_logits = np.exp((logits - np.max(logits, axis=-1, keepdims=True)) / tau) |
|
|
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True) |
|
|
|
|
|
def _circular_window_sum_np(row: np.ndarray, idx: int, half_w: int) -> float: |
|
|
K = row.size |
|
|
if half_w <= 0: |
|
|
return float(row[idx]) |
|
|
acc = 0.0 |
|
|
for d in range(-half_w, half_w + 1): |
|
|
acc += float(row[(idx + d) % K]) |
|
|
return acc |
|
|
|
|
|
def _parabolic_peak_refine_np(row: np.ndarray, k: int) -> float: |
|
|
K = row.size |
|
|
km1, kp1 = (k - 1) % K, (k + 1) % K |
|
|
y1, y2, y3 = float(row[km1]), float(row[k]), float(row[kp1]) |
|
|
denom = (y1 - 2 * y2 + y3) |
|
|
if abs(denom) < 1e-9: |
|
|
return 0.0 |
|
|
delta = 0.5 * (y1 - y3) / denom |
|
|
return float(max(min(delta, 0.5), -0.5)) |
|
|
|
|
|
def _min_circ_separation_bins(a: int, chosen: List[int], K: int) -> int: |
|
|
if not chosen: |
|
|
return K |
|
|
dmin = K |
|
|
for j in chosen: |
|
|
d = abs(a - j) |
|
|
d = min(d, K - d) |
|
|
dmin = min(dmin, d) |
|
|
return dmin |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def byte_to_float(data: bytes) -> np.ndarray:
    """Decode raw little-endian int16 PCM bytes to float32 in [-1.0, 1.0)."""
    return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
|
|
|
|
def chunk_to_floatarray(data: bytes, channels: int) -> np.ndarray: |
|
|
float_data = byte_to_float(data) |
|
|
return float_data.reshape(-1, channels).T |
|
|
|
|
|
def rms_dbfs(x: np.ndarray, eps: float = 1e-9) -> float:
    """RMS level of `x` in dBFS, floored at `eps` to avoid log(0)."""
    rms = np.sqrt(np.mean(np.square(x)))
    return 20.0 * np.log10(max(rms, eps))
|
|
|
|
|
def frame_rms_energy(audio_buffer: np.ndarray, T: int) -> np.ndarray:
    """Split audio_buffer (C, N) into T equal segments; return per-frame RMS
    normalized by its own mean (so a flat signal yields all ones).

    A non-positive T returns a single normalized value of 1.0.
    """
    _, N = audio_buffer.shape
    if T <= 0:
        return np.ones(1, dtype=np.float32)
    bounds = np.linspace(0, N, T + 1, dtype=int)
    rms_vals = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        seg = audio_buffer[:, lo:hi]
        # Degenerate (empty) segments contribute zero energy.
        rms_vals.append(0.0 if seg.size == 0 else float(np.sqrt((seg * seg).mean())))
    out = np.asarray(rms_vals, dtype=np.float32)
    return out / max(out.mean(), 1e-6)
|
|
|
|
|
def spectral_flux_per_frame(audio_buffer: np.ndarray, T: int) -> np.ndarray:
    """Per-frame spectral flux over T equal segments of the mono mix.

    Flux at frame t is the sum of positive magnitude increases relative to
    frame t-1, normalized by frame t's total magnitude. Frame 0 (and any
    frame with an empty spectrum) stays at 0. T <= 1 yields all zeros.
    """
    _, N = audio_buffer.shape
    if T <= 1:
        return np.zeros((T,), dtype=np.float32)

    mono = audio_buffer.mean(axis=0)
    bounds = np.linspace(0, N, T + 1, dtype=int)

    # Magnitude spectrum of each (Hann-windowed) segment.
    spectra = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        seg = mono[lo:hi]
        if seg.size == 0:
            spectra.append(np.zeros(1, dtype=np.float32))
            continue
        # Very short segments are left unwindowed.
        window = np.hanning(len(seg)) if len(seg) > 8 else np.ones_like(seg)
        spectrum = np.fft.rfft(seg * window, n=len(seg))
        spectra.append(np.abs(spectrum).astype(np.float32))

    flux = np.zeros(T, dtype=np.float32)
    for t in range(1, T):
        prev, cur = spectra[t - 1], spectra[t]
        L = min(len(prev), len(cur))
        if L == 0:
            continue
        rise = np.maximum(cur[:L] - prev[:L], 0.0)
        flux[t] = float(rise.sum() / (cur[:L].sum() + 1e-6))
    return flux
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OnsetDetector:
    """Track spectral-flux statistics with an EMA and expose an onset z-score.

    Also provides a static inter-channel coherence measure over the last
    of T equal segments of a multichannel buffer.
    """

    def __init__(self, alpha: float = 0.05):
        # EMA rate for the running mean/variance of the flux.
        self.alpha = float(alpha)
        self.mu = 0.0        # running mean of observed flux
        self.var = 1.0       # running variance of observed flux
        self.inited = False  # statistics seeded on first update

    def update_flux(self, flux_recent: float) -> float:
        """Update running flux statistics and return the z-score of `flux_recent`.

        The first observation seeds mu/var; the z-score for that call is
        exactly 0 (delta is zero immediately after seeding).
        """
        if not self.inited:
            self.mu = flux_recent
            self.var = 1e-3 + abs(flux_recent)
            self.inited = True
        delta = flux_recent - self.mu
        self.mu += self.alpha * delta
        self.var = (1 - self.alpha) * self.var + self.alpha * delta * delta
        sigma = max(np.sqrt(self.var), 1e-6)
        z = (flux_recent - self.mu) / sigma
        return float(z)

    @staticmethod
    def last_segment_coherence(audio_buffer: np.ndarray, T: int,
                               pairs: Tuple[Tuple[int, int], ...] = ((0, 1), (0, 2), (0, 3))) -> float:
        """Max |normalized cross-correlation| over mic `pairs` in the last of
        T equal segments of audio_buffer (C, N).

        Returns 0.0 for T < 1 or a segment shorter than 16 samples.

        Fix: the default `pairs` is now an immutable tuple instead of a
        mutable list (mutable-default-argument hazard); iteration semantics
        and default pair values are unchanged.
        """
        C, N = audio_buffer.shape
        if T < 1:
            return 0.0
        edges = np.linspace(0, N, T + 1, dtype=int)
        s0, s1 = int(edges[-2]), int(edges[-1])
        seg = audio_buffer[:, s0:s1]
        if seg.shape[1] < 16:
            return 0.0
        rmax = 0.0
        for (i, j) in pairs:
            # Demean each channel so the correlation ignores DC offset.
            xi = seg[i] - seg[i].mean()
            xj = seg[j] - seg[j].mean()
            denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-9)
            r = float(np.dot(xi, xj) / denom)
            rmax = max(rmax, abs(r))
        return rmax
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HistDOADetector:
    """Histogram-based multi-source DOA detector.

    Per-frame softmax posteriors over K azimuth bins are aggregated into a
    single circular histogram, with each frame weighted by its concentration
    (circular resultant length R raised to `gamma`) and an external VAD mask.
    Peaks are then picked from the smoothed histogram subject to height,
    window-mass and angular-separation constraints.
    """

    def __init__(
        self,
        K: int = 72,                     # number of azimuth bins (360/K degrees each)
        tau: float = 0.8,                # softmax temperature for per-frame posteriors
        gamma: float = 1.5,              # exponent on per-frame resultant R in the frame weight
        smooth_k: int = 1,               # half-width of circular moving-average smoothing (0 = off)
        window_bins: int = 1,            # half-width of the mass window around a candidate peak
        min_peak_height: float = 0.10,   # minimum histogram value at a peak bin
        min_window_mass: float = 0.24,   # minimum probability mass inside the peak window
        min_sep_deg: float = 20.0,       # minimum angular separation between reported peaks
        min_active_ratio: float = 0.20,  # minimum fraction of active frames to report any peak
        max_sources: int = 3,            # maximum number of peaks returned
        device: str = "cpu",             # torch device used for the aggregation math
    ):
        self.K = int(K)
        self.tau = float(tau)
        self.gamma = float(gamma)
        self.smooth_k = int(smooth_k)
        self.window_bins = int(window_bins)
        self.min_peak_height = float(min_peak_height)
        self.min_window_mass = float(min_window_mass)
        self.min_sep_deg = float(min_sep_deg)
        self.min_active_ratio = float(min_active_ratio)
        self.max_sources = int(max_sources)
        self.device = torch.device(device)
        # Precomputed bin-center angles and their cos/sin (torch tensors on self.device).
        self._deg, self._cos, self._sin, self._bin_size = self._angles_deg(self.K)

    def _angles_deg(self, K: int):
        """Return (centers_deg, cos, sin, bin_width_deg) for K equal azimuth bins."""
        bin_size = 360.0 / K
        deg = torch.arange(K, device=self.device, dtype=torch.float32) + 0.5
        deg = deg * bin_size
        rad = deg * math.pi / 180.0
        return deg, torch.cos(rad), torch.sin(rad), bin_size

    def _aggregate_histogram(self, logits: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Aggregate histogram from logits and VAD mask.

        Returns (hist, active_ratio, R_clip) where `hist` is the normalized
        (and optionally smoothed) K-bin histogram, `active_ratio` is the mean
        of the mask, and `R_clip` is the circular resultant length of `hist`.
        """
        logits_t = torch.from_numpy(logits).float().to(self.device)
        mask_t = torch.from_numpy(mask).float().to(self.device)

        # Per-frame posterior over azimuth bins at temperature tau.
        probs = F.softmax(logits_t / self.tau, dim=-1)
        T = probs.shape[0]  # NOTE(review): currently unused
        m = mask_t

        # Per-frame circular resultant length R_t in [0, 1]: concentration of
        # each frame's posterior. Sharp frames get larger aggregation weight.
        x = torch.matmul(probs, self._cos)
        y = torch.matmul(probs, self._sin)
        R_t = torch.clamp(torch.sqrt(x * x + y * y), 0, 1)
        w = m * (R_t ** self.gamma)

        # Degenerate case (all frames masked / fully diffuse): fall back to a
        # tiny uniform weight so the weighted sum below stays well-defined.
        if w.sum() <= 0:
            w = torch.ones_like(w) * 1e-6

        # Weighted sum of per-frame posteriors, renormalized to a distribution.
        hist = torch.matmul(w, probs)
        hist = hist / hist.sum().clamp_min(1e-8)

        # Circular moving-average smoothing via wrap-around padding + conv1d.
        if self.smooth_k > 0:
            s = self.smooth_k
            pad = torch.cat([hist[-s:], hist, hist[:s]], dim=0).view(1, 1, -1)
            kernel = torch.ones(1, 1, 2 * s + 1, device=self.device) / (2 * s + 1)
            hist = F.conv1d(pad, kernel, padding=0).view(-1)

        # Resultant length of the aggregated histogram: global concentration.
        X = torch.dot(hist, self._cos)
        Y = torch.dot(hist, self._sin)
        R_clip = float(torch.sqrt(X * X + Y * Y).item())
        active_ratio = float(m.mean().item())
        return hist.detach().cpu().numpy(), active_ratio, R_clip

    def _pick_peaks(self, hist: np.ndarray) -> List[Dict[str, float]]:
        """Pick peaks from histogram.

        Candidates are strict local maxima (circular neighbours), visited in
        descending height, and accepted only if they clear the separation,
        height and window-mass thresholds. Each accepted peak is refined to
        sub-bin accuracy with a parabolic fit.
        """
        hist_t = torch.from_numpy(hist).float()
        K = self.K
        bin_size = self._bin_size

        # Strict local maxima on the circular histogram.
        left = torch.roll(hist_t, 1, 0)
        right = torch.roll(hist_t, -1, 0)
        cand_idxs = ((hist_t > left) & (hist_t > right)).nonzero(as_tuple=False).flatten().tolist()
        cand_idxs.sort(key=lambda i: float(hist_t[i].item()), reverse=True)

        chosen, out = [], []
        min_sep_bins = max(1, int(round(self.min_sep_deg / bin_size)))

        for idx in cand_idxs:
            # Too close to an already-accepted peak.
            if _min_circ_separation_bins(idx, chosen, K) < min_sep_bins:
                continue
            # Peak bin itself too low.
            if float(hist_t[idx].item()) < self.min_peak_height:
                continue
            # Not enough probability mass concentrated around the peak.
            mass = _circular_window_sum_np(hist, idx, self.window_bins)
            if mass < self.min_window_mass:
                continue
            # Sub-bin refinement, then map bin index to degrees.
            delta = _parabolic_peak_refine_np(hist, idx)
            angle_deg = ((idx + 0.5 + delta) * bin_size) % 360.0
            out.append({"azimuth_deg": angle_deg, "score": float(mass)})
            chosen.append(idx)
            if len(out) >= self.max_sources:
                break
        return out

    def detect(self, logits: np.ndarray) -> Dict[str, object]:
        """Detect DOA from logits (no VAD separation).

        `logits` is (T, K); every frame is treated as active (mask of ones).
        Returns a dict with keys: peaks, active_ratio, R_clip, hist,
        bins_deg, has_event.
        """

        # No external VAD here: all frames count as active.
        mask = np.ones(logits.shape[0], dtype=np.float32)

        hist, active_ratio, R_clip = self._aggregate_histogram(logits, mask)

        # With an all-ones mask active_ratio is 1.0, so peaks are always picked
        # unless min_active_ratio > 1.
        peaks = self._pick_peaks(hist) if active_ratio >= self.min_active_ratio else []

        bins_deg = (np.arange(self.K) + 0.5) * (360.0 / self.K)
        return {
            "peaks": peaks,
            "active_ratio": active_ratio,
            "R_clip": R_clip,
            "hist": hist,
            "bins_deg": bins_deg,
            "has_event": bool(peaks),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LevelChangeGate:
    """Hysteresis gate that opens on a level jump above a slowly-adapting
    background estimate (or on a confident DOA peak) and closes once the
    level settles back, respecting hold and refractory times."""

    def __init__(
        self,
        delta_on_db: float = 2.5,
        delta_off_db: float = 1.0,
        level_min_dbfs: float = -60.0,
        ema_alpha: float = 0.05,
        min_R_clip: float = 0.18,
        hold_ms: int = 300,
        refractory_ms: int = 120
    ):
        # Opening/closing thresholds relative to the background level.
        self.delta_on_db = float(delta_on_db)
        self.delta_off_db = float(delta_off_db)
        # Absolute floor: never open on signals quieter than this.
        self.level_min_dbfs = float(level_min_dbfs)
        # Background EMA rate (always applied, even while the gate is open).
        self.ema_alpha = float(ema_alpha)
        # Minimum histogram resultant length for a peak to force the gate open.
        self.min_R_clip = float(min_R_clip)
        # Timing constraints, converted to seconds.
        self.hold_s = float(hold_ms) / 1000.0
        self.refractory_s = float(refractory_ms) / 1000.0
        # Mutable state.
        self.bg_dbfs = None          # background level estimate (lazy-seeded)
        self.active = False          # current gate state
        self.last_change_time = 0.0  # timestamp of the last open/close transition

    def update(self, level_dbfs: float, now_s: float,
               peaks_count: int, R_clip_max: float):
        """Advance the gate one step; returns (active, diff_db).

        `diff_db` is the current level minus the background estimate, computed
        before the background EMA absorbs this observation.
        """
        # Seed the background on the very first observation.
        if self.bg_dbfs is None:
            self.bg_dbfs = level_dbfs
        diff_db = level_dbfs - self.bg_dbfs

        past_refractory = (now_s - self.last_change_time) >= self.refractory_s
        loud_jump = level_dbfs > self.level_min_dbfs and diff_db >= self.delta_on_db
        confident_peak = peaks_count > 0 and R_clip_max >= self.min_R_clip

        if not self.active:
            # Open on a loud jump or a confident peak, once the refractory expired.
            if past_refractory and (loud_jump or confident_peak):
                self.active = True
                self.last_change_time = now_s
        elif (now_s - self.last_change_time) >= self.hold_s:
            # Close only after the hold time, when the level has settled back
            # and no confident peak remains.
            if diff_db <= self.delta_off_db and not confident_peak:
                self.active = False
                self.last_change_time = now_s

        # Background tracks the observed level regardless of gate state.
        self.bg_dbfs = (1.0 - self.ema_alpha) * self.bg_dbfs + self.ema_alpha * level_dbfs
        return self.active, diff_db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ONNXDOAStreaming:
    """Wrap an ONNX DOA model for streaming use: loads feature settings from
    the training config, builds the onnxruntime session, computes STFT
    features and runs fixed-size batched inference."""

    def __init__(
        self,
        onnx_path: str,
        config_path: Optional[str] = None,
        providers: Optional[list] = None
    ):
        # Default to the repo's training config so feature extraction matches
        # whatever the model was trained with.
        if config_path is None:
            config_path = project_root / "configs" / "train.yaml"
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # STFT / feature parameters (with fallbacks matching the defaults
        # used elsewhere in this script).
        self.features_cfg = self.config.get('features', {})
        self.sr = self.features_cfg.get('sr', 16000)
        self.win_s = self.features_cfg.get('win_s', 0.032)
        self.hop_s = self.features_cfg.get('hop_s', 0.010)
        self.nfft = self.features_cfg.get('nfft', 1024)
        self.K = self.features_cfg.get('K', 72)

        # Prefer CUDA but fall back to CPU when unavailable.
        if providers is None:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self.session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        input_shape = self.session.get_inputs()[0].shape
        # Heuristic model-type sniff: a 4-D input whose last dim is 513
        # (= nfft/2 + 1 for nfft=1024) is assumed to be the DoAEstimator
        # layout — NOTE(review): confirm for models exported with other nfft.
        self.is_doa_model = input_shape[-1] == 513 if len(input_shape) == 4 else False

        print(f"ONNX Model loaded: {onnx_path}")
        print(f" Input shape: {input_shape}")
        print(f" Model type: {'DoAEstimator' if self.is_doa_model else 'TFPoolClassifierNoCond'}")
        print(f" Providers: {self.session.get_providers()}")

    def compute_features(self, mixture: np.ndarray) -> np.ndarray:
        """Compute per-frame STFT features from a 4-channel mixture.

        Accepts (4, N) or (N, 4) input; raises ValueError for mono or a
        channel count other than 4. Returns the (T, 12, F) feature array
        produced by `compute_mag_phase_cos_sin` (12 channels asserted by
        `inference_batch`).
        """
        if mixture.ndim == 1:
            raise ValueError("Mixture must be multichannel (4 channels)")
        # Accept samples-first layout by transposing to channels-first.
        if mixture.shape[0] != 4 and mixture.shape[1] == 4:
            mixture = mixture.T

        if mixture.shape[0] != 4:
            raise ValueError(f"Expected 4 channels, got {mixture.shape[0]}")

        x4 = mixture.astype(np.float32)
        X, freqs, times = stft_multi(x4.T, fs=self.sr, win_s=self.win_s, hop_s=self.hop_s,
                                     nfft=self.nfft, window="hann", center=True, pad_mode="reflect")
        feats = compute_mag_phase_cos_sin(X, dtype=np.float32)
        return feats

    def inference_batch(self, feats: np.ndarray, batch_size: int = 25) -> np.ndarray:
        """Run the ONNX model over `feats` (T, 12, F) in fixed-size chunks.

        Each chunk is zero-padded up to `batch_size` frames (the exported
        model presumably expects a fixed time dimension — TODO confirm),
        reshaped to (1, 12, batch_size, F), and the padded frames are
        stripped from the output. Returns per-frame logits of shape (T, K).
        """
        T_frames, C_feat, F = feats.shape
        assert C_feat == 12, f"Expected 12 feature channels, got {C_feat}"

        all_logits = []
        for start_idx in range(0, T_frames, batch_size):
            end_idx = min(start_idx + batch_size, T_frames)
            batch_feats = feats[start_idx:end_idx]
            batch_T = batch_feats.shape[0]

            # Pad the final (short) chunk with zero frames to keep the model
            # input a fixed size.
            if batch_T < batch_size:
                padding = np.zeros((batch_size - batch_T, C_feat, F), dtype=batch_feats.dtype)
                batch_feats = np.concatenate([batch_feats, padding], axis=0)

            # (batch, C, T, F) layout expected by the exported graph.
            feats_tensor = batch_feats.transpose(1, 0, 2)[np.newaxis, ...]
            outputs = self.session.run([self.output_name], {self.input_name: feats_tensor.astype(np.float32)})
            batch_logits = outputs[0]

            # Normalize the various output layouts to (batch_T, K):
            # a single pooled (1, K) vector is broadcast to every frame,
            # otherwise padded frames are trimmed off.
            if batch_logits.ndim == 2:
                if batch_logits.shape[0] == 1 and batch_logits.shape[1] == self.K:
                    batch_logits = np.tile(batch_logits, (batch_T, 1))
                elif batch_logits.shape[0] == 1:
                    batch_logits = batch_logits[0]
                else:
                    batch_logits = batch_logits[:batch_T]
            elif batch_logits.ndim == 3:
                batch_logits = batch_logits[0, :batch_T]

            all_logits.append(batch_logits)

        return np.concatenate(all_logits, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CurrentLineVisualizer:
    """Live polar plot drawing the currently detected DOA(s) as radial lines."""

    def __init__(self, title: str = "Current DOA"):
        self.fig = plt.figure(figsize=(7.5, 7.5))
        self.ax = self.fig.add_subplot(111, projection='polar')
        self._setup_axes(title)
        # Interactive mode + non-blocking show so update() can redraw from
        # the streaming loop.
        plt.ion()
        plt.show(block=False)

    def _setup_axes(self, title: str):
        """Reset the polar axes: 0 deg at north, clockwise, unit radius."""
        self.ax.clear()
        self.ax.set_title(title, fontsize=13, fontweight='bold', pad=16)
        self.ax.set_theta_zero_location('N')
        # Clockwise direction to match compass-style azimuths.
        self.ax.set_theta_direction(-1)
        self.ax.set_thetalim(0, 2*np.pi)
        self.ax.set_ylim(0, 1.05)
        self.ax.set_yticklabels([])
        self.ax.add_patch(Circle((0, 0), 1.0, fill=False, color='gray', linestyle='--', linewidth=1, alpha=0.5))
        self.ax.grid(alpha=0.2)

    def update(self, peaks: List[Dict]):
        """Redraw with one radial line per peak (at most 3); an empty list
        clears the plot. Each peak dict needs "azimuth_deg" and optionally
        "score"."""
        self._setup_axes("Current DOA")

        for pk in peaks[:3]:
            az = float(pk["azimuth_deg"])
            sc = float(pk.get("score", 0.2))
            # Thicker line for higher peak mass (score clipped to [0, 0.6]).
            lw = 2.0 + 5.0 * float(np.clip(sc, 0.0, 0.6))
            theta = np.deg2rad(az)
            self.ax.plot([theta, theta], [0.0, 1.0], color='tab:green', linewidth=lw, solid_capstyle='round')
            self.ax.text(theta, 1.02, f"{az:.0f}°", ha='center', va='bottom', fontsize=10,
                         color='tab:green', fontweight='bold')

        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stream_onnx_inference(
    onnx_path: str,
    config_path: Optional[str] = None,
    device_index: Optional[int] = None,
    sample_rate: int = 16000,
    window_ms: int = 200,
    hop_ms: int = 100,
    chunk_size: int = 1600,
    cpu_only: bool = False,

    # Histogram detector parameters (see HistDOADetector).
    K: int = 72,
    tau: float = 0.8,
    smooth_k: int = 1,
    min_peak_height: float = 0.10,
    min_window_mass: float = 0.24,
    min_sep_deg: float = 20.0,
    min_active_ratio: float = 0.20,
    max_sources: int = 3,

    # Level-change gate parameters (see LevelChangeGate).
    level_delta_on_db: float = 2.5,
    level_delta_off_db: float = 1.0,
    level_min_dbfs: float = -60.0,
    level_ema_alpha: float = 0.05,
    event_hold_ms: int = 300,
    min_R_clip: float = 0.18,
    event_refractory_ms: int = 120,

    # Onset detector EMA rate.
    onset_alpha: float = 0.05,
):
    """Stream inference from microphone using ONNX model.

    Captures 6-channel audio (ReSpeaker-style device), keeps a sliding
    4-channel window of `window_ms`, advances by `hop_ms`, and for each full
    window runs feature extraction + ONNX inference + histogram peak picking.
    Results are printed and, while the level/peak gate is open, drawn on a
    live polar plot. Runs until Ctrl+C.
    """

    # --- Model, detector, gate, onset, visualizer setup -------------------
    providers = ['CPUExecutionProvider'] if cpu_only else ['CUDAExecutionProvider', 'CPUExecutionProvider']
    infer = ONNXDOAStreaming(onnx_path, config_path, providers=providers)

    # The model's bin count wins over the CLI value.
    if K != infer.K:
        print(f"Warning: K mismatch. Model K={infer.K}, requested K={K}. Using model K.")
        K = infer.K

    det = HistDOADetector(
        K=K, tau=tau, gamma=1.5, smooth_k=smooth_k,
        window_bins=1, min_peak_height=min_peak_height, min_window_mass=min_window_mass,
        min_sep_deg=min_sep_deg, min_active_ratio=min_active_ratio, max_sources=max_sources,
        device="cuda" if not cpu_only and torch.cuda.is_available() else "cpu"
    )

    gate = LevelChangeGate(
        delta_on_db=level_delta_on_db, delta_off_db=level_delta_off_db,
        level_min_dbfs=level_min_dbfs, ema_alpha=level_ema_alpha,
        min_R_clip=min_R_clip,
        hold_ms=event_hold_ms, refractory_ms=event_refractory_ms
    )

    onset = OnsetDetector(alpha=onset_alpha)
    visualizer = CurrentLineVisualizer()

    window_samples = int(sample_rate * window_ms / 1000)
    hop_samples = int(sample_rate * hop_ms / 1000)

    # --- Audio device discovery ------------------------------------------
    p = pyaudio.PyAudio()

    # Auto-detect a ReSpeaker device by name when no index was given.
    if device_index is None:
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            name = info['name'].lower()

            if 'respeaker' in name or 'seeed' in name or '2886' in name:
                device_index = i
                print(f"Auto-detected ReSpeaker at device {i}: {info['name']}")
                break

    if device_index is None:
        print("\n[Audio] Could not auto-detect Respeaker. Use --device-index or --list-devices.\n")
        p.terminate()
        return

    device_info = p.get_device_info_by_index(device_index)
    print(f"Device info: {device_info['name']}")
    print(f" Max input channels: {device_info['maxInputChannels']}")
    print(f" Default sample rate: {device_info['defaultSampleRate']:.0f} Hz")

    if device_info['maxInputChannels'] == 0:
        print(" Warning: Device reports 0 channels (may be managed by PulseAudio)")
        print(" Attempting to open anyway...")

    # The device exposes 6 channels; RAW_CHANNELS selects the 4 microphone
    # channels in model order — NOTE(review): mapping presumably matches the
    # ReSpeaker firmware channel layout; confirm against the device docs.
    CHANNELS = 6
    RAW_CHANNELS = [1, 4, 3, 2]
    FORMAT = pyaudio.paInt16

    # Sliding 4-channel analysis window and its fill level in samples.
    audio_buffer = np.zeros((4, window_samples), dtype=np.float32)
    buffer_fill = 0
    start_time = time.time()

    # Callback pushes raw chunks to a queue; the main loop consumes them.
    audio_queue = queue.Queue()
    stream_closed = False

    def _fill_buffer(in_data, frame_count, time_info, status_flags):
        # `stream_closed` is rebound in the outer finally; the closure sees
        # the updated value and stops queueing on shutdown.
        if not stream_closed:
            audio_queue.put(in_data)
        return None, pyaudio.paContinue

    # --- Open the capture stream -----------------------------------------
    try:

        stream = p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=sample_rate,
            input=True,
            input_device_index=device_index,
            frames_per_buffer=chunk_size,
            stream_callback=_fill_buffer
        )
        print(" Successfully opened audio stream with 6 channels")
    except Exception as e:
        print(f"\n[Audio] Could not open input device (index {device_index}).")
        print(f" Error: {e}")
        print("\n The ReSpeaker device is likely locked by PulseAudio.")
        print(" Solutions:")
        print(" 1. Temporarily stop PulseAudio: pulseaudio --kill")
        print(" 2. Then restart it after: pulseaudio --start")
        print(" 3. Or configure PulseAudio to allow direct ALSA access\n")
        p.terminate()
        return

    stream.start_stream()
    print(f"\n[Streaming] Started. Window: {window_ms}ms, Hop: {hop_ms}ms")
    print(" Press Ctrl+C to stop.\n")

    # --- Main processing loop --------------------------------------------
    try:
        while True:
            try:
                data = audio_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            # Select the 4 mic channels from the 6-channel interleaved chunk.
            chunk_all = chunk_to_floatarray(data, CHANNELS)
            audio_chunk = chunk_all[RAW_CHANNELS, :]
            n = audio_chunk.shape[1]

            # Accumulate until the analysis window is full.
            if buffer_fill + n <= window_samples:
                audio_buffer[:, buffer_fill:buffer_fill + n] = audio_chunk
                buffer_fill += n
                continue

            # Top the window off with the head of this chunk; the tail is
            # carried over after the hop-shift below.
            remaining = window_samples - buffer_fill
            if remaining > 0:
                audio_buffer[:, buffer_fill:] = audio_chunk[:, :remaining]
                buffer_fill = window_samples

            # Feature extraction + ONNX inference (timed).
            t0 = time.perf_counter()
            feats = infer.compute_features(audio_buffer)
            logits = infer.inference_batch(feats)
            t_model = (time.perf_counter() - t0) * 1000.0

            # Auxiliary per-frame signals for diagnostics.
            T = logits.shape[0]
            energies = frame_rms_energy(audio_buffer, T)  # NOTE(review): currently unused
            flux = spectral_flux_per_frame(audio_buffer, T)
            flux_recent = float(max(flux[-1], flux[-2] if T >= 2 else 0.0))
            flux_z = onset.update_flux(flux_recent)
            coh = OnsetDetector.last_segment_coherence(audio_buffer, T)

            # Histogram aggregation + peak picking (timed).
            t1 = time.perf_counter()
            det_result = det.detect(logits)
            t_hist = (time.perf_counter() - t1) * 1000.0

            peaks = det_result["peaks"]
            peaks_count = len(peaks)
            Rmax = det_result["R_clip"]

            level = rms_dbfs(audio_buffer)
            now = time.time() - start_time

            gate_open, diff_db = gate.update(level_dbfs=level, now_s=now,
                                             peaks_count=peaks_count, R_clip_max=Rmax)

            # Only show peaks while the gate is open; otherwise clear the plot.
            if gate_open:
                visualizer.update(peaks)
                gate_str = "OPEN "
            else:
                visualizer.update([])
                gate_str = "CLOSED"

            # One status line per processed window.
            print(f"[{now:6.2f}s] LVL={level:6.1f} dBFS diff={diff_db:+4.1f} | "
                  f"FLUXz={flux_z:4.2f} COH={coh:4.2f} | "
                  f"GATE={gate_str} | "
                  f"MODEL={t_model:5.1f}ms HIST={t_hist:5.1f}ms | "
                  f"DOA(R={Rmax:.2f}, n={peaks_count})", end="")
            if peaks:
                az_str = ", ".join([f"{p['azimuth_deg']:.0f}°" for p in peaks[:3]])
                print(f" [{az_str}]")
            else:
                print()

            # Slide the window left by one hop.
            audio_buffer[:, :-hop_samples] = audio_buffer[:, hop_samples:]
            buffer_fill = window_samples - hop_samples

            # Carry over leftover samples from the current chunk (at most one
            # hop's worth; anything beyond that is dropped).
            if n > remaining:
                carry = min(n - remaining, hop_samples)
                if carry > 0:
                    audio_buffer[:, buffer_fill:buffer_fill + carry] = audio_chunk[:, remaining:remaining + carry]
                    buffer_fill += carry

    except KeyboardInterrupt:
        print("\n[Streaming] Stopped by user.")
    finally:
        # Stop the callback from queueing, then tear down audio and plots.
        stream_closed = True
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
        p.terminate()
        plt.close('all')
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, optionally list audio devices,
    validate the model path, and launch the streaming loop.

    Fix: `chunk_size` was hard-coded to 1600 in the `stream_onnx_inference`
    call, silently ignoring the parsed `--chunk-size` argument; it now
    forwards `args.chunk_size` (default unchanged at 1600).
    """
    parser = argparse.ArgumentParser(description="Stream ONNX DOA inference from microphone")
    parser.add_argument('--onnx', type=str, required=False, help='Path to ONNX model file')
    parser.add_argument('--config', type=str, default=None, help='Path to config.yaml')
    parser.add_argument('--device-index', type=int, default=None, help='Audio device index')
    parser.add_argument('--sample-rate', type=int, default=16000, help='Sample rate (Hz)')
    parser.add_argument('--window-ms', type=int, default=200, help='Window length (ms)')
    parser.add_argument('--hop-ms', type=int, default=100, help='Hop length (ms)')
    parser.add_argument('--chunk-size', type=int, default=1600, help='Audio chunk size')
    parser.add_argument('--cpu-only', action='store_true', help='Use CPU only')
    parser.add_argument('--list-devices', action='store_true', help='List available audio devices')

    # Histogram detector parameters.
    parser.add_argument('--K', type=int, default=72, help='Number of azimuth bins')
    parser.add_argument('--tau', type=float, default=0.8, help='Softmax temperature')
    parser.add_argument('--smooth-k', type=int, default=1, help='Smoothing kernel size')
    parser.add_argument('--min-peak-height', type=float, default=0.10, help='Min peak height')
    parser.add_argument('--min-window-mass', type=float, default=0.24, help='Min window mass')
    parser.add_argument('--min-sep-deg', type=float, default=20.0, help='Min separation (deg)')
    parser.add_argument('--min-active-ratio', type=float, default=0.20, help='Min active ratio')
    parser.add_argument('--max-sources', type=int, default=3, help='Max sources')

    # Level-change gate parameters.
    parser.add_argument('--level-delta-on-db', type=float, default=2.5, help='Level delta on (dB)')
    parser.add_argument('--level-delta-off-db', type=float, default=1.0, help='Level delta off (dB)')
    parser.add_argument('--level-min-dbfs', type=float, default=-60.0, help='Min level (dBFS)')
    parser.add_argument('--level-ema-alpha', type=float, default=0.05, help='Level EMA alpha')
    parser.add_argument('--event-hold-ms', type=int, default=300, help='Event hold (ms)')
    parser.add_argument('--min-R-clip', type=float, default=0.18, help='Min R clip')
    parser.add_argument('--event-refractory-ms', type=int, default=120, help='Event refractory (ms)')

    # Onset detector parameter.
    parser.add_argument('--onset-alpha', type=float, default=0.05, help='Onset EMA alpha')

    args = parser.parse_args()

    # Device listing requires no model; print and exit.
    if args.list_devices:
        p = pyaudio.PyAudio()
        print("\nAvailable audio input devices:")
        print("-" * 80)
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if info['maxInputChannels'] > 0:
                print(f"Device {i}: {info['name']}")
                print(f" Channels: {info['maxInputChannels']}, Sample Rate: {info['defaultSampleRate']:.0f} Hz\n")
        p.terminate()
        return

    if args.onnx is None:
        parser.error("--onnx is required (unless using --list-devices)")

    onnx_path = Path(args.onnx)
    if not onnx_path.exists():
        parser.error(f"ONNX model not found: {onnx_path}")

    stream_onnx_inference(
        onnx_path=str(onnx_path),
        config_path=args.config,
        device_index=args.device_index,
        sample_rate=args.sample_rate,
        window_ms=args.window_ms,
        hop_ms=args.hop_ms,
        chunk_size=args.chunk_size,  # was hard-coded 1600, ignoring --chunk-size
        cpu_only=args.cpu_only,
        K=args.K,
        tau=args.tau,
        smooth_k=args.smooth_k,
        min_peak_height=args.min_peak_height,
        min_window_mass=args.min_window_mass,
        min_sep_deg=args.min_sep_deg,
        min_active_ratio=args.min_active_ratio,
        max_sources=args.max_sources,
        level_delta_on_db=args.level_delta_on_db,
        level_delta_off_db=args.level_delta_off_db,
        level_min_dbfs=args.level_min_dbfs,
        level_ema_alpha=args.level_ema_alpha,
        event_hold_ms=args.event_hold_ms,
        min_R_clip=args.min_R_clip,
        event_refractory_ms=args.event_refractory_ms,
        onset_alpha=args.onset_alpha,
    )
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|