voice-tools / src /lib /quality_metrics.py
jcudit's picture
jcudit HF Staff
fix: also correct lib/ in gitignore to only exclude root-level, add src/lib package
3ff2f18
"""
Audio quality metrics: SNR, STOI, PESQ calculation functions.
Provides objective quality measurements for audio extraction validation.
"""
import logging
from typing import Optional, Tuple
import numpy as np
# Module-level logger; validate_extraction_quality() reports per-metric failures through it.
logger = logging.getLogger(__name__)
class QualityMetricsError(Exception):
    """Raised when an audio quality metric (SNR, STOI or PESQ) cannot be computed."""
def calculate_snr(clean_signal: np.ndarray, noisy_signal: np.ndarray) -> float:
    """
    Calculate Signal-to-Noise Ratio (SNR) in dB.

    The noise component is taken to be ``noisy_signal - clean_signal`` and the
    result is ``10 * log10(mean(clean**2) / mean(noise**2))``. Higher values
    indicate cleaner audio.

    If the signals differ in length, both are truncated to the shorter one.
    Inputs are converted to float64 first so integer PCM (e.g. int16) does not
    overflow when squared or subtracted.

    Args:
        clean_signal: Clean reference signal
        noisy_signal: Signal with noise

    Returns:
        SNR in dB as a Python float. ``inf`` when the signals are identical
        (zero noise power, including two silent signals); ``-inf`` when the
        reference is silent but noise is present.

    Raises:
        QualityMetricsError: If either signal is empty or the calculation fails
    """
    try:
        # Promote to float64: squaring int16 samples would otherwise wrap around.
        clean = np.asarray(clean_signal, dtype=np.float64)
        noisy = np.asarray(noisy_signal, dtype=np.float64)

        # Compare only the overlapping part of the two signals.
        min_len = min(len(clean), len(noisy))
        if min_len == 0:
            # np.mean of an empty array would silently produce NaN.
            raise ValueError("signals must be non-empty")
        clean = clean[:min_len]
        noisy = noisy[:min_len]

        noise = noisy - clean
        signal_power = np.mean(clean**2)
        noise_power = np.mean(noise**2)

        # Checked first so two silent signals report +inf rather than -inf.
        if noise_power == 0:
            return float("inf")
        if signal_power == 0:
            return float("-inf")

        return float(10 * np.log10(signal_power / noise_power))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate SNR: {e}") from e
def calculate_snr_segmental(
    signal: np.ndarray, sample_rate: int, frame_length_ms: int = 20
) -> float:
    """
    Estimate a segmental level for a signal without a clean reference.

    NOTE(review): despite the name, no signal/noise separation happens here.
    The signal is split into frames of ``frame_length_ms`` with 50% overlap,
    each frame's mean power is converted to dB (``10 * log10(mean(frame**2))``),
    and the per-frame values are averaged. Silent (zero-power) frames are skipped.

    Args:
        signal: Audio signal
        sample_rate: Sample rate in Hz
        frame_length_ms: Frame length in milliseconds

    Returns:
        Mean per-frame power in dB as a Python float, or 0.0 when no frame has
        non-zero power (including signals shorter than one frame).

    Raises:
        QualityMetricsError: If the frame length rounds to zero samples or the
            calculation fails
    """
    try:
        # Promote to float64 so integer PCM does not overflow when squared.
        samples = np.asarray(signal, dtype=np.float64)
        frame_length = int(sample_rate * frame_length_ms / 1000)
        if frame_length < 1:
            raise ValueError(
                f"frame length must be at least one sample, got {frame_length} "
                f"({frame_length_ms} ms at {sample_rate} Hz)"
            )
        # 50% overlap; max() keeps the step positive for 1-sample frames.
        hop_length = max(1, frame_length // 2)

        frame_levels_db = []
        # "+ 1" so a frame ending exactly at the last sample is included.
        for start in range(0, len(samples) - frame_length + 1, hop_length):
            frame_power = np.mean(samples[start : start + frame_length] ** 2)
            if frame_power > 0:  # log10(0) is -inf, so silent frames are skipped
                frame_levels_db.append(10 * np.log10(frame_power))

        if not frame_levels_db:
            return 0.0
        return float(np.mean(frame_levels_db))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate segmental SNR: {e}") from e
def calculate_stoi(
    clean_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, extended: bool = True
) -> float:
    """
    Calculate Short-Time Objective Intelligibility (STOI) score via pystoi.

    Measures speech intelligibility; nominal range 0-1 (higher = better).
    Extended STOI (e-STOI) is better for intermediate quality levels.
    If the signals differ in length, both are truncated to the shorter one.

    Args:
        clean_signal: Clean reference signal
        degraded_signal: Degraded signal to evaluate
        sample_rate: Sample rate in Hz
        extended: Use extended STOI (default: True)

    Returns:
        STOI score as a Python float

    Raises:
        QualityMetricsError: If pystoi is unavailable or the calculation fails
    """
    try:
        # Imported inside the function, presumably so this module can be
        # imported without pystoi installed.
        from pystoi import stoi

        # Compare only the overlapping part of the two signals.
        min_len = min(len(clean_signal), len(degraded_signal))
        score = stoi(
            clean_signal[:min_len], degraded_signal[:min_len], sample_rate, extended=extended
        )
        return float(score)
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate STOI: {e}") from e
def calculate_pesq(
    reference_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, mode: str = "wb"
) -> float:
    """
    Calculate Perceptual Evaluation of Speech Quality (PESQ) score via `pesq`.

    Correlates with human perception of quality (higher = better). The raw
    ITU-T P.862 scale is -0.5 to 4.5. NOTE(review): the `pesq` package reports
    MOS-LQO values (roughly 1.0-4.6); confirm which scale callers expect.
    If the signals differ in length, both are truncated to the shorter one.

    Args:
        reference_signal: Reference (clean) signal
        degraded_signal: Degraded signal to evaluate
        sample_rate: Sample rate in Hz. Wideband ('wb') requires 16000;
            narrowband ('nb') accepts 8000 or 16000.
        mode: 'wb' (wideband) or 'nb' (narrowband)

    Returns:
        PESQ score as a Python float

    Raises:
        QualityMetricsError: If mode or sample rate is invalid, pesq is
            unavailable, or the calculation fails
    """
    # Validate arguments before importing pesq so bad calls get a clear
    # message even when the optional dependency is missing.
    if mode == "wb":
        if sample_rate != 16000:
            raise QualityMetricsError(
                f"Wideband PESQ requires 16kHz sample rate, got {sample_rate}Hz. "
                "Resample before calling this function."
            )
    elif mode == "nb":
        if sample_rate not in (8000, 16000):
            raise QualityMetricsError(
                f"Narrowband PESQ requires 8kHz or 16kHz sample rate, got {sample_rate}Hz. "
                "Resample before calling this function."
            )
    else:
        raise QualityMetricsError(f"Unknown PESQ mode {mode!r}; expected 'wb' or 'nb'.")

    try:
        from pesq import pesq

        # Compare only the overlapping part of the two signals.
        min_len = min(len(reference_signal), len(degraded_signal))
        score = pesq(sample_rate, reference_signal[:min_len], degraded_signal[:min_len], mode)
        return float(score)
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ: {e}") from e
def calculate_pesq_with_resampling(
    reference_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, mode: str = "wb"
) -> float:
    """
    Calculate PESQ after resampling both signals to the rate the mode requires.

    Signals are first truncated to a common length. If ``sample_rate`` differs
    from the target (16 kHz for 'wb', 8 kHz for 'nb'), both are resampled with
    ``scipy.signal.resample`` (FFT-based) to the same target length.

    Args:
        reference_signal: Reference signal
        degraded_signal: Degraded signal
        sample_rate: Current sample rate in Hz
        mode: 'wb' (wideband, 16kHz) or 'nb' (narrowband, 8kHz)

    Returns:
        PESQ score as a Python float

    Raises:
        QualityMetricsError: If mode is invalid, a dependency is unavailable,
            or the calculation fails
    """
    try:
        from pesq import pesq
        from scipy.signal import resample

        if mode not in ("wb", "nb"):
            # Previously any non-"wb" mode silently fell through to 8 kHz.
            raise ValueError(f"unknown PESQ mode {mode!r}; expected 'wb' or 'nb'")

        # Compare only the overlapping part of the two signals.
        min_len = min(len(reference_signal), len(degraded_signal))
        reference = reference_signal[:min_len]
        degraded = degraded_signal[:min_len]

        target_sr = 16000 if mode == "wb" else 8000
        if sample_rate != target_sr:
            # Resample both to the same length so PESQ compares aligned signals.
            target_len = int(min_len * target_sr / sample_rate)
            reference = resample(reference, target_len)
            degraded = resample(degraded, target_len)

        return float(pesq(target_sr, reference, degraded, mode))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ with resampling: {e}") from e
def validate_extraction_quality(
    original_signal: np.ndarray,
    extracted_signal: np.ndarray,
    sample_rate: int,
    snr_threshold: float = 20.0,
    stoi_threshold: float = 0.75,
    pesq_threshold: float = 2.5,
) -> dict:
    """
    Validate extraction quality against thresholds.

    Computes SNR, extended STOI and wideband PESQ (resampled to 16 kHz when
    needed) between the original and extracted signals, and compares each to
    its minimum threshold. A metric whose calculation raises is logged as a
    warning and left as ``None`` with its pass flag ``False``.

    Args:
        original_signal: Original (noisy) signal
        extracted_signal: Extracted (cleaned) signal
        sample_rate: Sample rate in Hz
        snr_threshold: Minimum SNR in dB (default: 20)
        stoi_threshold: Minimum STOI score (default: 0.75)
        pesq_threshold: Minimum PESQ score (default: 2.5)

    Returns:
        Dictionary with keys ``snr``, ``stoi``, ``pesq`` (metric value or None),
        ``snr_pass``, ``stoi_pass``, ``pesq_pass`` and ``overall_pass`` (plain
        bools). ``overall_pass`` is True only if all three metrics were
        computed and met their thresholds.
    """
    results = {
        "snr": None,
        "snr_pass": False,
        "stoi": None,
        "stoi_pass": False,
        "pesq": None,
        "pesq_pass": False,
        "overall_pass": False,
    }

    def _evaluate(key, label, compute, threshold):
        """Run one metric, storing its value and pass flag; log and continue on failure."""
        try:
            value = compute()
            results[key] = value
            # bool() so the flag is a plain bool, not numpy.bool_, for NumPy scalar metrics.
            results[f"{key}_pass"] = bool(value >= threshold)
        except Exception as e:
            logger.warning("%s calculation failed: %s", label, e)

    # NOTE(review): original_signal is passed as calculate_snr's clean reference,
    # so this measures power(original) / power(extracted - original); confirm intent.
    _evaluate(
        "snr", "SNR",
        lambda: calculate_snr(original_signal, extracted_signal),
        snr_threshold,
    )
    _evaluate(
        "stoi", "STOI",
        lambda: calculate_stoi(original_signal, extracted_signal, sample_rate, extended=True),
        stoi_threshold,
    )
    _evaluate(
        "pesq", "PESQ",
        lambda: calculate_pesq_with_resampling(
            original_signal, extracted_signal, sample_rate, mode="wb"
        ),
        pesq_threshold,
    )

    # Every metric must pass; one that failed to compute keeps its flag False
    # and therefore fails the overall check.
    results["overall_pass"] = all(results[f"{k}_pass"] for k in ("snr", "stoi", "pesq"))
    return results
def get_quality_label(metric_name: str, value: float) -> str:
    """
    Map a metric value to a human-readable quality label.

    Each metric has a descending ladder of (exclusive lower bound, label)
    pairs; the first bound that ``value`` exceeds wins. Values that exceed no
    bound (including NaN) get the metric's floor label.

    Args:
        metric_name: Metric name ('snr', 'stoi', 'pesq')
        value: Metric value

    Returns:
        Quality label string, or "Unknown" for an unrecognized metric name
    """
    ladders = {
        "snr": ((40, "Excellent"), (30, "Very Good"), (20, "Good"), (10, "Fair")),
        "stoi": ((0.9, "Excellent"), (0.8, "Very Good"), (0.7, "Good"), (0.6, "Fair")),
        "pesq": ((3.5, "Excellent"), (3.0, "Good"), (2.5, "Fair"), (2.0, "Poor")),
    }
    floor_labels = {"snr": "Poor", "stoi": "Poor", "pesq": "Bad"}

    if metric_name not in ladders:
        return "Unknown"

    for bound, label in ladders[metric_name]:
        if value > bound:
            return label
    return floor_labels[metric_name]
def generate_quality_report(metrics: dict) -> str:
    """
    Render the output of validate_extraction_quality() as a text report.

    One line is emitted per metric whose value is not None, showing the value,
    PASS/FAIL from the matching ``*_pass`` key, and a label from
    get_quality_label(). The report always ends with the overall verdict.

    Args:
        metrics: Dictionary from validate_extraction_quality()

    Returns:
        Formatted multi-line report string
    """
    # (dict key, display name, format spec for the value, unit suffix)
    row_specs = (
        ("snr", "SNR", ".2f", " dB"),
        ("stoi", "STOI", ".3f", ""),
        ("pesq", "PESQ", ".2f", ""),
    )

    lines = ["=== Voice Extraction Quality Report ===", ""]
    for key, title, spec, unit in row_specs:
        value = metrics[key]
        if value is None:
            continue
        verdict = "PASS" if metrics[f"{key}_pass"] else "FAIL"
        label = get_quality_label(key, value)
        lines.append(f"{title}: {value:{spec}}{unit} [{verdict}] - {label}")

    overall_verdict = "PASS" if metrics["overall_pass"] else "FAIL"
    lines.extend(["", f"Overall Quality: [{overall_verdict}]"])
    return "\n".join(lines)