# marionette/tests/audio_analysis.py
# Author: RemiFabre — commit dbc544f ("Refactor marionette into modules,
# fix audio sync, improve tests")
"""Audio analysis utilities for sync testing.
Generates test audio with beep markers, creates antenna collision
trajectories, and detects both beep onsets and transient events
in recorded audio for measuring audio-motion sync accuracy.
"""
from __future__ import annotations
import numpy as np
# ──────── Test Signal Generation ──────────────────────────────────
# Non-periodic beep times so large offsets can't accidentally align.
DEFAULT_BEEP_TIMES: list[float] = [1.0, 2.5, 4.0, 5.0, 7.0]
def generate_sync_test_audio(
    sr: int = 48000,
    duration: float = 8.0,
    beep_times: list[float] | None = None,
    beep_freq: float = 1000.0,
    beep_duration: float = 0.1,
) -> tuple[np.ndarray, list[float]]:
    """Generate mono float32 audio with tonal beeps at known timestamps.

    Args:
        sr: Sample rate in Hz.
        duration: Total clip length in seconds.
        beep_times: Beep start times in seconds. Defaults to
            DEFAULT_BEEP_TIMES (non-periodic on purpose, so large sync
            offsets can't accidentally align with a later beep).
        beep_freq: Beep tone frequency in Hz.
        beep_duration: Length of each beep in seconds.

    Returns:
        (audio_data, beep_timestamps). ``beep_timestamps`` contains only
        the beeps that were actually rendered — a beep that would run past
        the end of the clip (or start before it) is dropped from both the
        audio and the returned list, so downstream matching never expects
        a beep that isn't audible.

    Beeps are sine bursts with a 5 ms linear fade-in/out to avoid clicks.
    """
    if beep_times is None:
        beep_times = DEFAULT_BEEP_TIMES
    n_total = int(sr * duration)
    audio = np.zeros(n_total, dtype=np.float32)
    rendered: list[float] = []
    # Loop invariants: every beep has the same length and fade window.
    n_beep = int(beep_duration * sr)
    fade = int(0.005 * sr)  # 5 ms fade
    for t in beep_times:
        start_sample = int(t * sr)
        if start_sample < 0 or start_sample + n_beep > n_total:
            # Beep would fall (partly) outside the clip: skip it entirely
            # rather than emit a truncated burst.
            continue
        t_arr = np.arange(n_beep, dtype=np.float32) / sr
        beep = 0.5 * np.sin(2 * np.pi * beep_freq * t_arr).astype(np.float32)
        if fade > 0 and 2 * fade < n_beep:
            beep[:fade] *= np.linspace(0, 1, fade, dtype=np.float32)
            beep[-fade:] *= np.linspace(1, 0, fade, dtype=np.float32)
        audio[start_sample : start_sample + n_beep] += beep
        rendered.append(t)
    return audio, rendered
def generate_collision_trajectory(
    beep_times: list[float],
    duration: float,
    motion_sr: int = 100,
) -> tuple[list[float], list[dict]]:
    """Generate motion frames where antennas collide at each beep time.

    Antennas rest apart at ±0.3 rad (≈ ±17°) and meet at 0.0 rad exactly
    at each beep timestamp. Each collision is a linear 100 ms approach
    followed by a linear 100 ms return (~200 ms total).

    Args:
        beep_times: Collision timestamps in seconds.
        duration: Total trajectory length in seconds.
        motion_sr: Frame rate of the trajectory in Hz.

    Returns:
        (timestamps, frames) where each frame dict carries "head" (4x4
        identity pose as nested lists), "antennas" ([-value, +value] rad),
        "body_yaw" (always 0.0), and "check_collision" (always False —
        collisions are intentional here).
    """
    n = int(duration * motion_sr)
    dt = 1.0 / motion_sr
    rest_pos = 0.3  # rad, antennas apart
    ramp = 0.1  # seconds per half of the collision motion
    timestamps: list[float] = []
    frames: list[dict] = []
    identity_pose = np.eye(4)
    for i in range(n):
        t = i * dt
        timestamps.append(t)
        # Default: antennas at rest unless t lies inside a collision window.
        antenna_val = rest_pos
        for bt in beep_times:
            if bt - ramp <= t <= bt:
                # Approaching: linear ease from rest to 0 so contact lands
                # exactly on the beep time.
                antenna_val = rest_pos * (1.0 - (t - (bt - ramp)) / ramp)
                break
            elif bt < t <= bt + ramp:
                # Returning: linear ease from 0 back to rest.
                antenna_val = rest_pos * ((t - bt) / ramp)
                break
        frames.append({
            "head": identity_pose.tolist(),
            "antennas": [-antenna_val, antenna_val],
            "body_yaw": 0.0,
            "check_collision": False,
        })
    return timestamps, frames
# ──────── Audio Analysis ──────────────────────────────────────────
def detect_beep_onsets(
    audio: np.ndarray,
    sr: int,
    freq: float = 1000.0,
    bandwidth: float = 200.0,
    threshold_db: float = -20.0,
    min_separation: float = 0.3,
) -> list[float]:
    """Detect onset times (seconds) of tonal beeps via bandpass + envelope.

    Pipeline:
        1. Bandpass filter around ``freq`` (±``bandwidth``/2).
        2. Amplitude envelope: rectification + 50 Hz lowpass.
        3. Rising-edge crossings of ``threshold_db`` (relative to the
           envelope peak), debounced by ``min_separation`` seconds.

    Returns an empty list for silent or empty input.
    """
    from scipy.signal import butter, sosfilt
    if audio.size == 0:
        return []  # avoid np.max on an empty array below
    # Bandpass around the beep frequency (cutoffs normalized to Nyquist).
    low = max(20, freq - bandwidth / 2) / (sr / 2)
    high = min(0.99, (freq + bandwidth / 2) / (sr / 2))
    sos = butter(4, [low, high], btype="bandpass", output="sos")
    filtered = sosfilt(sos, audio.astype(np.float64))
    # Amplitude envelope: full-wave rectify, then smooth with a 50 Hz lowpass.
    envelope = np.abs(filtered)
    lp_freq = min(50.0 / (sr / 2), 0.99)
    sos_lp = butter(2, lp_freq, btype="lowpass", output="sos")
    envelope = sosfilt(sos_lp, envelope)
    # Normalize so threshold_db is relative to the loudest beep.
    peak_val = np.max(envelope)
    if peak_val < 1e-10:
        return []  # effectively silent: nothing to detect
    envelope /= peak_val
    threshold = 10 ** (threshold_db / 20)
    # Rising-edge threshold crossings (onset = first sample above threshold).
    above = envelope > threshold
    edges = np.diff(above.astype(np.int8))
    onset_indices = list(np.where(edges > 0)[0] + 1)
    if above[0]:
        # Already above threshold at t=0: there is no rising edge, so a beep
        # starting at the very beginning would otherwise be missed.
        onset_indices.insert(0, 0)
    # Debounce: keep only onsets at least min_separation apart.
    min_distance = int(min_separation * sr)
    if len(onset_indices) > 1:
        kept = [onset_indices[0]]
        for idx in onset_indices[1:]:
            if idx - kept[-1] >= min_distance:
                kept.append(idx)
        onset_indices = kept
    return [idx / sr for idx in onset_indices]
def detect_transient_onsets(
    audio: np.ndarray,
    sr: int,
    highpass_freq: float = 2000.0,
    threshold_db: float = -20.0,
) -> list[float]:
    """Detect impulsive sounds (antenna collisions) via energy flux.

    Pipeline:
        1. High-pass filter to separate transients from the tonal beep band.
        2. Short-term energy in non-overlapping 5 ms windows.
        3. Half-wave-rectified energy difference (flux), peak-picked with a
           minimum 200 ms separation.

    Returns onset times in seconds; empty list for silent/too-short input.
    """
    from scipy.signal import butter, sosfilt, find_peaks
    if audio.size == 0:
        return []
    # High-pass to isolate transients (cutoff normalized to Nyquist).
    hp_freq = min(highpass_freq / (sr / 2), 0.99)
    sos = butter(4, hp_freq, btype="highpass", output="sos")
    filtered = sosfilt(sos, audio.astype(np.float64))
    # Short-term energy in 5 ms windows.
    win_samples = max(1, int(0.005 * sr))
    energy = np.array([
        np.sum(filtered[i : i + win_samples] ** 2)
        for i in range(0, len(filtered) - win_samples, win_samples)
    ])
    if len(energy) < 2:
        return []
    # Onset strength: positive energy differences only.
    flux = np.maximum(np.diff(energy), 0)
    peak_val = np.max(flux)
    if peak_val < 1e-10:
        return []
    flux /= peak_val
    threshold = 10 ** (threshold_db / 20)
    # Peaks at least 200 ms apart.
    min_distance = max(1, int(0.2 * sr / win_samples))
    peaks, _ = find_peaks(flux, height=threshold, distance=min_distance)
    # flux[p] = energy[p+1] - energy[p]: the jump happens going INTO window
    # p+1, so the onset lies in window p+1. Reporting p directly would be
    # one window (~5 ms) early — significant for sync measurement.
    return [((p + 1) * win_samples) / sr for p in peaks]
def measure_sync_offsets(
    beep_onsets: list[float],
    collision_onsets: list[float],
    max_match_distance: float = 0.5,
) -> dict:
    """Greedily pair each beep with its closest unused collision.

    A beep matches only if its nearest remaining collision lies within
    ``max_match_distance`` seconds; each collision is consumed at most once.
    Positive offsets mean the collision arrived AFTER the beep.

    Returns a dict with the matched (beep_t, collision_t, offset_ms) pairs,
    the match/beep/collision counts, and mean/max/std of the offsets in
    milliseconds (NaN when nothing matched).
    """
    unused = list(collision_onsets)
    matched: list[tuple[float, float, float]] = []
    for beep_t in beep_onsets:
        if not unused:
            break
        # Index of the nearest remaining collision to this beep.
        nearest = min(range(len(unused)), key=lambda k: abs(unused[k] - beep_t))
        if abs(unused[nearest] - beep_t) <= max_match_distance:
            coll_t = unused.pop(nearest)
            matched.append((beep_t, coll_t, (coll_t - beep_t) * 1000.0))
    offsets = [delta for _, _, delta in matched]
    return {
        "pairs": matched,
        "n_matched": len(matched),
        "n_beeps": len(beep_onsets),
        "n_collisions": len(collision_onsets),
        "mean_offset_ms": float(np.mean(offsets)) if offsets else float("nan"),
        "max_offset_ms": float(np.max(np.abs(offsets))) if offsets else float("nan"),
        "std_offset_ms": float(np.std(offsets)) if offsets else float("nan"),
    }
# ──────── Mic Recording Helper ────────────────────────────────────
class MicRecorder:
    """Capture audio from the default input device via ``sounddevice``.

    Usage:
        recorder = MicRecorder(sr=48000)
        recorder.start()
        # ... do stuff ...
        audio = recorder.stop()  # returns np.ndarray
    """
    def __init__(self, sr: int = 48000, channels: int = 1):
        self.sr = sr
        self.channels = channels
        self._frames: list[np.ndarray] = []
        self._stream = None
    def start(self) -> None:
        """Open an input stream and begin buffering callback frames."""
        import sounddevice as sd
        self._frames = []
        def _on_audio(indata, frames, time_info, status):
            # Copy: the driver reuses the buffer between callbacks.
            self._frames.append(indata.copy())
        self._stream = sd.InputStream(
            samplerate=self.sr,
            channels=self.channels,
            dtype="float32",
            callback=_on_audio,
        )
        self._stream.start()
    def stop(self) -> np.ndarray:
        """Stop recording and return everything captured as a mono array."""
        stream = self._stream
        if stream is not None:
            stream.stop()
            stream.close()
            self._stream = None
        if not self._frames:
            return np.zeros(0, dtype=np.float32)
        recorded = np.concatenate(self._frames, axis=0)
        # Collapse multi-channel input to mono via the first channel.
        return recorded[:, 0] if recorded.ndim > 1 else recorded