"""
Matplotlib visualization for SELD demo.

Produces a 14-panel figure (1 waveform + 13 per-class event tracks).
Mirrors the canvas rendering logic in App.vue.
"""

import numpy as np
import matplotlib

matplotlib.use('Agg')  # headless rendering — must be set before pyplot import

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.figure import Figure
from typing import List, Dict


CLASS_NAMES = [
    "Female speech",
    "Male speech",
    "Clapping",
    "Telephone",
    "Laughter",
    "Domestic sounds",
    "Walk/Footsteps",
    "Door",
    "Music",
    "Musical instr.",
    "Water tap",
    "Bell",
    "Knock",
]


def build_seld_figure(
    waveform_mono: np.ndarray,
    sr: int,
    events: List[Dict],
    audio_duration: float,
    title: str = "",
    *,
    waveform_gain: float = 18.0,
    max_waveform_points: int = 4000,
) -> Figure:
    """Build the full SELD visualization figure.

    Args:
        waveform_mono: 1D array-like, mono waveform (for display only).
        sr: sample rate. Not used by the current rendering: waveform samples
            are spread evenly over ``[0, audio_duration]`` regardless of it.
        events: list of dicts with keys {"time", "class", "azi", "ele"}.
            Only "time" and "class" are drawn; events whose class index is
            outside ``[0, len(CLASS_NAMES))`` are skipped.
        audio_duration: total audio length in seconds (fixes x-axis).
        title: optional figure title (e.g. mode name).
        waveform_gain: display gain applied to the waveform before it is
            drawn on a fixed ``[-1, 1]`` y-range; values beyond that range
            are clipped by the axis limits.
        max_waveform_points: upper bound on the number of waveform samples
            plotted; the waveform is decimated by a constant stride.

    Returns:
        matplotlib.figure.Figure (caller is responsible for plt.close(fig)).
    """
    n_classes = len(CLASS_NAMES)
    n_panels = 1 + n_classes  # waveform + 13 class tracks

    fig, axes = plt.subplots(
        n_panels, 1,
        figsize=(12, 2.0 + n_classes * 0.9),
        sharex=True,
    )

    # ------------------------------------------------------------------
    # Panel 0: Waveform
    # ------------------------------------------------------------------
    ax_wave = axes[0]
    # Float conversion avoids integer overflow when applying the gain and
    # also accepts plain Python sequences.
    wave = np.asarray(waveform_mono, dtype=float)
    t = np.linspace(0, audio_duration, len(wave))
    # Ceil-divide so the stride really keeps at most max_waveform_points
    # samples (floor division could keep nearly twice as many).
    limit = max(1, int(max_waveform_points))
    step = max(1, -(-len(wave) // limit))
    ax_wave.plot(t[::step], wave[::step] * waveform_gain, color='#3b82f6', linewidth=0.6)
    ax_wave.set_ylim(-1, 1)
    ax_wave.set_ylabel("Waveform", fontsize=7)
    ax_wave.set_yticks([])
    ax_wave.axhline(0, color='#94a3b8', linewidth=0.4)
    if title:
        ax_wave.set_title(title, fontsize=9, loc='left', pad=3)

    # ------------------------------------------------------------------
    # Panels 1–13: Per-class event tracks
    # ------------------------------------------------------------------
    # Group event times by class index for quick lookup.
    times_by_class: Dict[int, List[float]] = {i: [] for i in range(n_classes)}
    for ev in events:
        c = ev["class"]
        if 0 <= c < n_classes:
            times_by_class[c].append(ev["time"])

    for class_idx, (class_name, ax) in enumerate(zip(CLASS_NAMES, axes[1:])):
        ax.set_ylim(-1, 1)
        ax.axhline(0, color='#cbd5e1', linewidth=0.5)
        ax.set_yticks([])
        # Class label on y-axis
        ax.set_ylabel(class_name, fontsize=6, rotation=0, labelpad=80, va='center')

        class_times = times_by_class[class_idx]
        if class_times:
            # One LineCollection per track instead of one Line2D per event.
            # Because the y-range is fixed to [-1, 1], this looks the same as
            # full-height axvline markers but uses far fewer artists.
            ax.vlines(class_times, -1, 1, colors='#22c55e', linewidths=1.2, alpha=0.85)

    # ------------------------------------------------------------------
    # Shared x-axis
    # ------------------------------------------------------------------
    axes[-1].set_xlabel("Time (s)", fontsize=8)
    axes[-1].set_xlim(0, audio_duration)

    fig.tight_layout(h_pad=0.15)
    return fig


def make_empty_figure(audio_duration: float, title: str = "") -> Figure:
    """Return a blank figure with correct axis limits (before inference starts).

    Builds the standard figure from an all-zero placeholder waveform sampled
    at 100 Hz over ``audio_duration`` seconds, with no events.
    """
    dummy_waveform = np.zeros(int(audio_duration * 100))
    return build_seld_figure(dummy_waveform, 100, [], audio_duration, title=title)