File size: 3,717 Bytes
c6cd0b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Matplotlib visualization for SELD demo.
Produces a 14-panel figure (1 waveform + 13 per-class event tracks).
Mirrors the canvas rendering logic in App.vue.
"""
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless rendering — must be set before pyplot import
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.figure import Figure
from typing import List, Dict

# Display labels for the per-class event tracks. An event whose "class" value
# is ``c`` is drawn in the panel labelled ``CLASS_NAMES[c]``; the list length
# also sets how many track panels the figure has.
# NOTE(review): order presumably has to match the model's class indices and
# the labels used in App.vue — confirm before reordering or renaming.
CLASS_NAMES: List[str] = [
    "Female speech",    # 0
    "Male speech",      # 1
    "Clapping",         # 2
    "Telephone",        # 3
    "Laughter",         # 4
    "Domestic sounds",  # 5
    "Walk/Footsteps",   # 6
    "Door",             # 7
    "Music",            # 8
    "Musical instr.",   # 9
    "Water tap",        # 10
    "Bell",             # 11
    "Knock",            # 12
]


def build_seld_figure(
    waveform_mono: np.ndarray,
    sr: int,
    events: List[Dict],
    audio_duration: float,
    title: str = "",
    *,
    waveform_gain: float = 18.0,
    max_waveform_points: int = 4000,
) -> Figure:
    """Build the full SELD visualization figure.

    The figure stacks one waveform panel on top of one event-track panel per
    entry in ``CLASS_NAMES``; all panels share the time (x) axis.

    Args:
        waveform_mono: 1D array-like, mono waveform (for display only).
        sr: sample rate. Currently not used for drawing: the waveform is
            spread evenly over ``[0, audio_duration]``.
        events: list of dicts with keys {"time", "class", "azi", "ele"}.
            Only "time" (seconds) and "class" (index into CLASS_NAMES) are
            drawn; events whose class index is outside
            ``[0, len(CLASS_NAMES))`` are skipped.
        audio_duration: total audio length in seconds (fixes x-axis).
        title: optional figure title (e.g. mode name).
        waveform_gain: display gain applied to the waveform before drawing
            it in a fixed [-1, 1] y-range; anything larger is clipped by the
            axis limits.
        max_waveform_points: upper bound on the number of waveform points
            plotted (the waveform is decimated with a constant stride).

    Returns:
        matplotlib.figure.Figure (caller is responsible for plt.close(fig)).
    """
    n_classes = len(CLASS_NAMES)
    n_panels = 1 + n_classes  # waveform + one track per class

    fig, axes = plt.subplots(
        n_panels, 1,
        figsize=(12, 2.0 + n_classes * 0.9),
        sharex=True,
    )

    # Panel 0: waveform
    _draw_waveform_panel(
        axes[0], waveform_mono, audio_duration, title,
        gain=waveform_gain, max_points=max_waveform_points,
    )

    # Panels 1..n_classes: per-class event tracks
    times_by_class = _group_event_times_by_class(events, n_classes)
    for class_idx, ax in enumerate(axes[1:]):
        _draw_class_track(ax, CLASS_NAMES[class_idx], times_by_class[class_idx])

    # Shared x-axis: setting it on the last axes propagates via sharex=True.
    axes[-1].set_xlabel("Time (s)", fontsize=8)
    axes[-1].set_xlim(0, audio_duration)

    fig.tight_layout(h_pad=0.15)
    return fig


def _draw_waveform_panel(
    ax,
    waveform_mono: np.ndarray,
    audio_duration: float,
    title: str,
    *,
    gain: float,
    max_points: int,
) -> None:
    """Draw the decimated, gain-scaled waveform (plus optional title) into *ax*."""
    samples = np.asarray(waveform_mono)
    n_samples = len(samples)
    t = np.linspace(0, audio_duration, n_samples)
    # Ceiling division keeps the plotted point count <= max_points; floor
    # division would plot up to ~2x max_points for lengths just below a
    # multiple of it.
    step = max(1, -(-n_samples // max(1, max_points)))
    ax.plot(t[::step], samples[::step] * gain, color='#3b82f6', linewidth=0.6)
    ax.set_ylim(-1, 1)
    ax.set_ylabel("Waveform", fontsize=7)
    ax.set_yticks([])
    ax.axhline(0, color='#94a3b8', linewidth=0.4)
    if title:
        ax.set_title(title, fontsize=9, loc='left', pad=3)


def _group_event_times_by_class(
    events: List[Dict], n_classes: int
) -> Dict[int, List[float]]:
    """Map every class index in range(n_classes) to the list of its event times."""
    grouped: Dict[int, List[float]] = {i: [] for i in range(n_classes)}
    for ev in events:
        c = ev["class"]
        if 0 <= c < n_classes:
            grouped[c].append(ev["time"])
    return grouped


def _draw_class_track(ax, label: str, times: List[float]) -> None:
    """Draw one class panel: centre line, y-axis label and a marker per event time."""
    ax.set_ylim(-1, 1)
    ax.axhline(0, color='#cbd5e1', linewidth=0.5)
    ax.set_yticks([])
    ax.set_ylabel(label, fontsize=6, rotation=0, labelpad=80, va='center')
    if times:
        # A single LineCollection instead of one Line2D per event (cheaper
        # for long recordings); with ylim pinned to [-1, 1] the segments span
        # the full panel height, as axvline did.
        ax.vlines(times, -1, 1, colors='#22c55e', linewidth=1.2, alpha=0.85)


def make_empty_figure(audio_duration: float, title: str = "") -> Figure:
    """Build a placeholder figure shown before any inference results exist.

    The waveform panel shows silence and every class track is empty. The
    x-axis already spans ``[0, audio_duration]``, so the layout matches the
    figure that will later replace it.
    """
    placeholder_sr = 100  # samples per second of the silent stand-in waveform
    silence = np.zeros(int(audio_duration * placeholder_sr))
    no_events: List[Dict] = []
    return build_seld_figure(
        silence,
        placeholder_sr,
        no_events,
        audio_duration,
        title=title,
    )