# SELD_Streaming_RC / inference / visualization.py
# dongyuxuan
# Initial HF Space with Xet audio assets
# c6cd0b7
"""
Matplotlib visualization for SELD demo.
Produces a 14-panel figure (1 waveform + 13 per-class event tracks).
Mirrors the canvas rendering logic in App.vue.
"""
import numpy as np
import matplotlib
matplotlib.use('Agg') # headless rendering — must be set before pyplot import
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.figure import Figure
from typing import List, Dict
# Display labels for the sound-event classes, one per event-track panel.
# The position in this list is the class index: an event's "class" value
# selects the panel labelled with the name stored at that index.
CLASS_NAMES = [
    "Female speech",    # 0
    "Male speech",      # 1
    "Clapping",         # 2
    "Telephone",        # 3
    "Laughter",         # 4
    "Domestic sounds",  # 5
    "Walk/Footsteps",   # 6
    "Door",             # 7
    "Music",            # 8
    "Musical instr.",   # 9
    "Water tap",        # 10
    "Bell",             # 11
    "Knock",            # 12
]
def _draw_waveform_panel(ax, waveform_mono, audio_duration: float, title: str) -> None:
    """Draw the (decimated) mono waveform and the optional title on *ax*."""
    max_points = 4000
    # Fixed display gain. Samples pushed beyond +/-1 by the gain fall outside
    # the y-limits set below and are therefore not visible.
    # NOTE(review): presumably matches the gain used by the App.vue canvas —
    # confirm.
    display_gain = 18

    # Coerce to an array so that `* display_gain` scales the samples even if
    # a plain Python list is passed (list * int would repeat the list).
    samples = np.asarray(waveform_mono)
    n_samples = len(samples)

    # The time axis spans [0, audio_duration] regardless of the sample rate.
    t = np.linspace(0, audio_duration, n_samples)

    # Ceiling division: guarantees that no more than `max_points` samples are
    # plotted (floor division would allow almost twice as many).
    step = max(1, -(-n_samples // max_points))

    ax.plot(t[::step], samples[::step] * display_gain,
            color='#3b82f6', linewidth=0.6)
    ax.set_ylim(-1, 1)
    ax.set_ylabel("Waveform", fontsize=7)
    ax.set_yticks([])
    ax.axhline(0, color='#94a3b8', linewidth=0.4)
    if title:
        ax.set_title(title, fontsize=9, loc='left', pad=3)


def _draw_class_panels(class_axes, events: List[Dict]) -> None:
    """Draw one event track per class: a vertical line at every event time."""
    n_classes = len(CLASS_NAMES)

    # Group event times by class index; out-of-range classes are skipped.
    times_by_class: Dict[int, List[float]] = {i: [] for i in range(n_classes)}
    for ev in events:
        c = ev["class"]
        if 0 <= c < n_classes:
            times_by_class[c].append(ev["time"])

    for class_idx, (ax, class_name) in enumerate(zip(class_axes, CLASS_NAMES)):
        ax.set_ylim(-1, 1)
        ax.axhline(0, color='#cbd5e1', linewidth=0.5)
        ax.set_yticks([])
        # Horizontal class label to the left of the track.
        ax.set_ylabel(class_name, fontsize=6, rotation=0,
                      labelpad=80, va='center')
        for t_ev in times_by_class[class_idx]:
            ax.axvline(x=t_ev, color='#22c55e', linewidth=1.2, alpha=0.85)


def build_seld_figure(
    waveform_mono: np.ndarray,
    sr: int,
    events: List[Dict],
    audio_duration: float,
    title: str = "",
) -> Figure:
    """Build the full SELD visualization figure.

    The figure stacks one waveform panel on top of one event-track panel per
    entry in ``CLASS_NAMES``. All panels share the x-axis (time in seconds).

    Args:
        waveform_mono: 1D np.ndarray, mono waveform (for display only).
        sr: sample rate. Currently unused: the time axis is derived from
            ``audio_duration`` and the number of samples.
        events: list of dicts with keys {"time", "class", "azi", "ele"}.
            Only "time" and "class" are read here. Events whose "class" lies
            outside ``0 .. len(CLASS_NAMES) - 1`` are ignored.
        audio_duration: total audio length in seconds (fixes x-axis).
        title: optional figure title (e.g. mode name).

    Returns:
        matplotlib.figure.Figure (caller is responsible for plt.close(fig)).

    Raises:
        Any exception raised while drawing is re-raised after the partially
        built figure has been closed.
    """
    n_classes = len(CLASS_NAMES)
    n_panels = 1 + n_classes  # waveform + one track per class

    fig, axes = plt.subplots(
        n_panels, 1,
        figsize=(12, 2.0 + n_classes * 0.9),
        sharex=True,
    )
    try:
        _draw_waveform_panel(axes[0], waveform_mono, audio_duration, title)
        _draw_class_panels(axes[1:], events)

        # Shared x-axis: label and limits are set on the bottom panel only.
        axes[-1].set_xlabel("Time (s)", fontsize=8)
        axes[-1].set_xlim(0, audio_duration)
        fig.tight_layout(h_pad=0.15)
    except Exception:
        # The caller never receives the figure on failure, so it could not
        # close it; release it from pyplot here before propagating the error.
        plt.close(fig)
        raise
    return fig
def make_empty_figure(audio_duration: float, title: str = "") -> Figure:
    """Return a blank figure with correct axis limits (before inference starts).

    An all-zero waveform and an empty event list are passed to
    ``build_seld_figure``.

    Args:
        audio_duration: total audio length in seconds.
        title: optional figure title, forwarded unchanged.

    Returns:
        The figure produced by ``build_seld_figure``.
    """
    # Nominal rate used only to size the all-zero placeholder waveform.
    placeholder_sr = 100
    n_samples = int(audio_duration * placeholder_sr)
    silence = np.zeros(n_samples)
    return build_seld_figure(
        waveform_mono=silence,
        sr=placeholder_sr,
        events=[],
        audio_duration=audio_duration,
        title=title,
    )