File size: 3,489 Bytes
a9536c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# -*- coding: utf-8 -*-
"""Reference-based audio metrics for separation/cover evaluation."""
from __future__ import annotations

from typing import Mapping

import numpy as np


# Small floor (in power units) applied to both numerator and denominator before
# forming a dB ratio, so a silent signal or a perfect reconstruction does not
# produce log10(0) or a division by zero. Also used as the "silent reference"
# threshold and denominator guard in the SI-SDR computation.
EPS = 1e-10


def _as_mono_float(audio: np.ndarray) -> np.ndarray:
    arr = np.asarray(audio, dtype=np.float64)
    if arr.ndim == 2:
        arr = np.mean(arr, axis=1)
    return arr.reshape(-1)


def _align_pair(reference: np.ndarray, estimate: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Convert both signals to mono float64 and trim them to a shared length.

    The longer signal is truncated to the shorter one's length, keeping the
    leading samples. No time-shift alignment is performed; the caller is
    responsible for providing time-aligned audio.

    Raises:
        ValueError: If either signal has zero samples.
    """
    ref_mono, est_mono = (_as_mono_float(signal) for signal in (reference, estimate))
    common_len = min(ref_mono.size, est_mono.size)
    if common_len <= 0:
        raise ValueError("Audio metric received empty reference or estimate.")
    return ref_mono[:common_len], est_mono[:common_len]


def _power(audio: np.ndarray) -> float:
    arr = np.asarray(audio, dtype=np.float64).reshape(-1)
    return float(np.sum(arr * arr))


def _db_ratio(signal_power: float, noise_power: float) -> float:
    """Return ``10 * log10(signal_power / noise_power)`` in decibels.

    Both powers are first clamped from below to ``EPS`` so that zero (or
    negative) inputs yield a large finite value instead of ``inf``/``nan``.
    """
    numerator, denominator = (max(float(p), EPS) for p in (signal_power, noise_power))
    return float(np.log10(numerator / denominator) * 10.0)


def signal_distortion_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Scale-dependent SDR: 10 log10(||s||^2 / ||s - shat||^2).

    Both inputs are downmixed to mono and trimmed to the shorter length
    before the ratio is formed. Any gain mismatch between *estimate* and
    *reference* counts as error.
    """
    ref, est = _align_pair(reference, estimate)
    error = ref - est
    return _db_ratio(_power(ref), _power(error))


def scale_invariant_signal_distortion_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """SI-SDR in dB, as used by modern source-separation literature.

    Steps:
    1. Both signals are downmixed to mono, trimmed to a common length and made
       zero-mean.
    2. The estimate is projected onto the reference (optimal scalar gain).
    3. The projection's energy is compared against the energy of what is left
       over, so a pure gain change in *estimate* does not affect the score.

    Raises:
        ValueError: If the zero-mean reference has (near-)zero energy.
    """
    ref, est = _align_pair(reference, estimate)
    # Remove DC offset from both signals before projecting.
    ref_centered = ref - float(np.mean(ref))
    est_centered = est - float(np.mean(est))

    ref_energy = _power(ref_centered)
    if ref_energy <= EPS:
        raise ValueError("SI-SDR reference is silent.")

    # Least-squares gain that best maps the reference onto the estimate.
    alpha = float(np.dot(est_centered, ref_centered) / (ref_energy + EPS))
    projection = alpha * ref_centered
    distortion = est_centered - projection
    return _db_ratio(_power(projection), _power(distortion))


def signal_to_noise_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Return the scale-dependent reconstruction SNR in dB.

    This is numerically identical to ``signal_distortion_ratio``; the separate
    name exists for callers that report the value as SNR.
    """
    sdr_value = signal_distortion_ratio(reference, estimate)
    return sdr_value


def evaluate_reference_stems(
    references: Mapping[str, np.ndarray],
    estimates: Mapping[str, np.ndarray],
) -> dict:
    """Compute true reference-based metrics for every stem in ``references``.

    The caller must provide time-aligned reference stems. Without references,
    SI-SDR/SDR cannot be interpreted as source-separation quality.

    Args:
        references: Mapping from stem name to reference (ground-truth) audio.
        estimates: Mapping from stem name to estimated audio. It must contain
            every stem in ``references``. Stems present only in ``estimates``
            are ignored.

    Returns:
        A dict with these keys:

        - ``"mean_si_sdr"``, ``"mean_sdr"``, ``"mean_snr"``: the unweighted
          mean of each metric (in dB) across stems.
        - ``"stems"``: maps each stem name to its ``{"si_sdr", "sdr", "snr"}``
          values.

    Raises:
        ValueError: If ``references`` is empty, or if a metric rejects a stem
            (e.g. empty audio or a silent SI-SDR reference).
        KeyError: If a reference stem has no matching estimate. All pairings
            are checked before any metric is computed.
    """
    if not references:
        raise ValueError("No reference stems were provided.")

    # Fail fast: validate the full stem pairing before spending time on metrics.
    for stem_name in references:
        if stem_name not in estimates:
            raise KeyError(f"Missing estimated stem for reference: {stem_name}")

    stem_metrics: dict[str, dict[str, float]] = {}
    for stem_name, reference_audio in references.items():
        estimate_audio = estimates[stem_name]
        stem_metrics[stem_name] = {
            "si_sdr": scale_invariant_signal_distortion_ratio(reference_audio, estimate_audio),
            "sdr": signal_distortion_ratio(reference_audio, estimate_audio),
            "snr": signal_to_noise_ratio(reference_audio, estimate_audio),
        }

    # Aggregate each metric across stems (plain mean of the dB values).
    summary: dict = {
        f"mean_{metric}": float(np.mean([values[metric] for values in stem_metrics.values()]))
        for metric in ("si_sdr", "sdr", "snr")
    }
    summary["stems"] = stem_metrics
    return summary