"""
Audio quality metrics: SNR, STOI, PESQ calculation functions.

Provides objective quality measurements for audio extraction validation.
"""

import logging
from typing import Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

# Sample rate (Hz) that each PESQ mode is evaluated at in this module.
_PESQ_SAMPLE_RATES = {"wb": 16000, "nb": 8000}

# Per metric: ((exclusive lower bound, label), ...) ordered from best to
# worst, followed by the label used when no bound is exceeded.
_QUALITY_LABELS = {
    "snr": (
        ((40, "Excellent"), (30, "Very Good"), (20, "Good"), (10, "Fair")),
        "Poor",
    ),
    "stoi": (
        ((0.9, "Excellent"), (0.8, "Very Good"), (0.7, "Good"), (0.6, "Fair")),
        "Poor",
    ),
    "pesq": (
        ((3.5, "Excellent"), (3.0, "Good"), (2.5, "Fair"), (2.0, "Poor")),
        "Bad",
    ),
}

# (metrics key, line template) for each metric row of the text report.
_REPORT_ROWS = (
    ("snr", "SNR: {value:.2f} dB [{status}] - {quality}"),
    ("stoi", "STOI: {value:.3f} [{status}] - {quality}"),
    ("pesq", "PESQ: {value:.2f} [{status}] - {quality}"),
)


class QualityMetricsError(Exception):
    """Custom exception for quality metric calculation errors."""


def _trim_to_common_length(first, second):
    """Truncate both signals to the length of the shorter one."""
    common_len = min(len(first), len(second))
    return first[:common_len], second[:common_len]


def _validate_pesq_mode(mode: str) -> int:
    """Return the sample rate required by *mode*, or raise for unknown modes."""
    try:
        return _PESQ_SAMPLE_RATES[mode]
    except (KeyError, TypeError):
        raise QualityMetricsError(
            f"Invalid PESQ mode {mode!r}: expected 'wb' or 'nb'"
        ) from None


def calculate_snr(clean_signal: np.ndarray, noisy_signal: np.ndarray) -> float:
    """
    Calculate Signal-to-Noise Ratio (SNR) in dB.

    The noise is taken to be ``noisy_signal - clean_signal``. If the two
    signals differ in length, both are truncated to the shorter one.
    Higher values indicate cleaner audio.

    Args:
        clean_signal: Clean reference signal
        noisy_signal: Signal with noise

    Returns:
        SNR in dB. ``inf`` when the noise power is zero, ``-inf`` when the
        reference power is zero (and the noise power is not).

    Raises:
        QualityMetricsError: If the signals are empty or the calculation fails
    """
    try:
        clean, noisy = _trim_to_common_length(clean_signal, noisy_signal)

        # Work in float64: squaring integer PCM (e.g. int16) in its own dtype
        # wraps around silently and yields meaningless power values.
        clean = np.asarray(clean, dtype=np.float64)
        noisy = np.asarray(noisy, dtype=np.float64)

        if clean.size == 0:
            raise QualityMetricsError(
                "Failed to calculate SNR: signals are empty"
            )

        noise_power = np.mean((noisy - clean) ** 2)
        # Edge case: no noise at all (checked first, so two all-zero signals
        # also report +inf).
        if noise_power == 0:
            return float("inf")

        signal_power = np.mean(clean**2)
        # Edge case: silent reference.
        if signal_power == 0:
            return float("-inf")

        return float(10 * np.log10(signal_power / noise_power))
    except QualityMetricsError:
        raise
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate SNR: {e}") from e


def calculate_snr_segmental(
    signal: np.ndarray, sample_rate: int, frame_length_ms: int = 20
) -> float:
    """
    Calculate a reference-free, frame-based level estimate in dB.

    The signal is split into half-overlapping frames; for every frame with
    non-zero power, ``10 * log10(mean(frame ** 2))`` is computed and the
    results are averaged. Note that no noise estimate is involved: the value
    is the mean frame power in dB, not a true signal-to-noise ratio.

    Args:
        signal: Audio signal
        sample_rate: Sample rate in Hz
        frame_length_ms: Frame length in milliseconds

    Returns:
        Mean per-frame power in dB, or 0.0 if no frame has non-zero power
        (including when the signal is shorter than one frame).

    Raises:
        QualityMetricsError: If the frame length is shorter than one sample
            or the calculation fails
    """
    try:
        frame_length = int(sample_rate * frame_length_ms / 1000)
        if frame_length < 1:
            raise QualityMetricsError(
                "Failed to calculate segmental SNR: frame length is shorter "
                "than one sample"
            )
        # 50% overlap; at least one sample so range() never gets a zero step.
        hop_length = max(1, frame_length // 2)

        # float64 avoids integer wrap-around when squaring PCM samples.
        samples = np.asarray(signal, dtype=np.float64)

        frame_levels_db = []
        # "+ 1" so that the last full frame (start == len - frame_length) is
        # included.
        for start in range(0, len(samples) - frame_length + 1, hop_length):
            frame_power = np.mean(samples[start : start + frame_length] ** 2)
            if frame_power > 0:
                frame_levels_db.append(10 * np.log10(frame_power))

        if not frame_levels_db:
            return 0.0
        return float(np.mean(frame_levels_db))
    except QualityMetricsError:
        raise
    except Exception as e:
        raise QualityMetricsError(
            f"Failed to calculate segmental SNR: {e}"
        ) from e


def calculate_stoi(
    clean_signal: np.ndarray,
    degraded_signal: np.ndarray,
    sample_rate: int,
    extended: bool = True,
) -> float:
    """
    Calculate Short-Time Objective Intelligibility (STOI) score.

    Measures speech intelligibility. Range: 0-1 (higher = better).
    Extended STOI (e-STOI) is better for intermediate quality levels.
    Signals of different length are truncated to the shorter one. Requires
    the optional ``pystoi`` package, imported on first use.

    Args:
        clean_signal: Clean reference signal
        degraded_signal: Degraded signal to evaluate
        sample_rate: Sample rate in Hz
        extended: Use extended STOI (default: True)

    Returns:
        STOI score (0-1)

    Raises:
        QualityMetricsError: If ``pystoi`` is unavailable or calculation fails
    """
    try:
        from pystoi import stoi

        clean, degraded = _trim_to_common_length(clean_signal, degraded_signal)
        return float(stoi(clean, degraded, sample_rate, extended=extended))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate STOI: {e}") from e


def calculate_pesq(
    reference_signal: np.ndarray,
    degraded_signal: np.ndarray,
    sample_rate: int,
    mode: str = "wb",
) -> float:
    """
    Calculate Perceptual Evaluation of Speech Quality (PESQ) score.

    Correlates with human perception of quality. Range: -0.5 to 4.5
    (higher = better). Signals of different length are truncated to the
    shorter one. Requires the optional ``pesq`` package, imported on first
    use.

    Args:
        reference_signal: Reference (clean) signal
        degraded_signal: Degraded signal to evaluate
        sample_rate: Sample rate in Hz (16000 for 'wb', 8000 for 'nb')
        mode: 'wb' (wideband, 16kHz) or 'nb' (narrowband, 8kHz)

    Returns:
        PESQ score

    Raises:
        QualityMetricsError: If the mode or sample rate is invalid, ``pesq``
            is unavailable, or calculation fails
    """
    # Validate arguments before touching the optional dependency so callers
    # get the most specific error message.
    _validate_pesq_mode(mode)
    if mode == "wb" and sample_rate != 16000:
        raise QualityMetricsError(
            f"Wideband PESQ requires 16kHz sample rate, got {sample_rate}Hz. "
            "Resample before calling this function."
        )
    if mode == "nb" and sample_rate != 8000:
        raise QualityMetricsError(
            f"Narrowband PESQ requires 8kHz sample rate, got {sample_rate}Hz. "
            "Resample before calling this function."
        )

    try:
        from pesq import pesq

        reference, degraded = _trim_to_common_length(
            reference_signal, degraded_signal
        )
        return float(pesq(sample_rate, reference, degraded, mode))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ: {e}") from e


def calculate_pesq_with_resampling(
    reference_signal: np.ndarray,
    degraded_signal: np.ndarray,
    sample_rate: int,
    mode: str = "wb",
) -> float:
    """
    Calculate PESQ with automatic resampling to the required sample rate.

    Both signals are truncated to the shorter length and, when
    ``sample_rate`` differs from the rate required by ``mode``, resampled
    with ``scipy.signal.resample`` before scoring. Requires the optional
    ``pesq`` package and SciPy, imported on first use.

    Args:
        reference_signal: Reference signal
        degraded_signal: Degraded signal
        sample_rate: Current sample rate
        mode: 'wb' (wideband, 16kHz) or 'nb' (narrowband, 8kHz)

    Returns:
        PESQ score

    Raises:
        QualityMetricsError: If the mode is invalid, a dependency is
            unavailable, or calculation fails
    """
    target_sr = _validate_pesq_mode(mode)

    try:
        from pesq import pesq
        from scipy.signal import resample

        reference, degraded = _trim_to_common_length(
            reference_signal, degraded_signal
        )

        if sample_rate != target_sr:
            # Same target length for both signals keeps them sample-aligned.
            target_len = int(len(reference) * target_sr / sample_rate)
            reference = resample(reference, target_len)
            degraded = resample(degraded, target_len)

        return float(pesq(target_sr, reference, degraded, mode))
    except Exception as e:
        raise QualityMetricsError(
            f"Failed to calculate PESQ with resampling: {e}"
        ) from e


def _record_metric(results: dict, name: str, compute, threshold: float) -> None:
    """Store one metric and its pass flag in *results*; log and skip on failure."""
    try:
        results[name] = compute()
        # bool() keeps the flag a plain Python bool (not np.bool_), so the
        # result dictionary stays JSON-serialisable.
        results[f"{name}_pass"] = bool(results[name] >= threshold)
    except Exception as e:
        logger.warning("%s calculation failed: %s", name.upper(), e)


def validate_extraction_quality(
    original_signal: np.ndarray,
    extracted_signal: np.ndarray,
    sample_rate: int,
    snr_threshold: float = 20.0,
    stoi_threshold: float = 0.75,
    pesq_threshold: float = 2.5,
) -> dict:
    """
    Validate extraction quality against thresholds.

    Calculates SNR, extended STOI and wideband PESQ (with resampling) using
    ``original_signal`` as the reference, and checks each against its
    minimum threshold. A metric that cannot be calculated is logged as a
    warning, left as ``None`` and counted as not passed.

    Args:
        original_signal: Original (noisy) signal
        extracted_signal: Extracted (cleaned) signal
        sample_rate: Sample rate in Hz
        snr_threshold: Minimum SNR in dB (default: 20)
        stoi_threshold: Minimum STOI score (default: 0.75)
        pesq_threshold: Minimum PESQ score (default: 2.5)

    Returns:
        Dictionary with keys ``snr``, ``stoi``, ``pesq`` (value or ``None``),
        ``snr_pass``, ``stoi_pass``, ``pesq_pass`` and ``overall_pass``
        (all ``bool``).
    """
    results = {
        "snr": None,
        "snr_pass": False,
        "stoi": None,
        "stoi_pass": False,
        "pesq": None,
        "pesq_pass": False,
        "overall_pass": False,
    }

    _record_metric(
        results,
        "snr",
        lambda: calculate_snr(original_signal, extracted_signal),
        snr_threshold,
    )
    _record_metric(
        results,
        "stoi",
        lambda: calculate_stoi(
            original_signal, extracted_signal, sample_rate, extended=True
        ),
        stoi_threshold,
    )
    _record_metric(
        results,
        "pesq",
        lambda: calculate_pesq_with_resampling(
            original_signal, extracted_signal, sample_rate, mode="wb"
        ),
        pesq_threshold,
    )

    # Overall pass requires every metric to have been calculated AND to have
    # met its threshold; a metric that failed to compute fails the validation.
    results["overall_pass"] = bool(
        results["snr_pass"] and results["stoi_pass"] and results["pesq_pass"]
    )
    return results


def get_quality_label(metric_name: str, value: float) -> str:
    """
    Get quality label for a metric value.

    Thresholds are exclusive lower bounds: a value must be strictly greater
    than a bound to earn its label.

    Args:
        metric_name: Metric name ('snr', 'stoi', 'pesq')
        value: Metric value

    Returns:
        Quality label string, or "Unknown" for an unrecognised metric name
    """
    scale = _QUALITY_LABELS.get(metric_name)
    if scale is None:
        return "Unknown"

    bounds, lowest_label = scale
    for lower_bound, label in bounds:
        if value > lower_bound:
            return label
    return lowest_label


def generate_quality_report(metrics: dict) -> str:
    """
    Generate human-readable quality report.

    Metrics whose value is missing or ``None`` are omitted from the report.

    Args:
        metrics: Dictionary from validate_extraction_quality()

    Returns:
        Formatted report string
    """
    report = ["=== Voice Extraction Quality Report ===", ""]

    for key, template in _REPORT_ROWS:
        value = metrics.get(key)
        if value is None:
            continue
        report.append(
            template.format(
                value=value,
                status="PASS" if metrics.get(f"{key}_pass") else "FAIL",
                quality=get_quality_label(key, value),
            )
        )

    overall = "PASS" if metrics.get("overall_pass") else "FAIL"
    report.append("")
    report.append(f"Overall Quality: [{overall}]")
    return "\n".join(report)