Spaces:
Running on Zero
Running on Zero
| """ | |
| Audio quality metrics: SNR, STOI, PESQ calculation functions. | |
| Provides objective quality measurements for audio extraction validation. | |
| """ | |
| import logging | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
class QualityMetricsError(Exception):
    """Raised when an audio quality metric cannot be computed."""
def calculate_snr(clean_signal: np.ndarray, noisy_signal: np.ndarray) -> float:
    """
    Calculate Signal-to-Noise Ratio (SNR) in dB.

    The noise is taken to be ``noisy_signal - clean_signal`` and the result is
    ``10 * log10(mean(clean**2) / mean(noise**2))``. Higher values indicate
    cleaner audio.

    Args:
        clean_signal: Clean reference signal.
        noisy_signal: Signal with noise.

    Returns:
        SNR in dB as a Python float. ``inf`` when the noise power is exactly
        zero (the signals are identical), ``-inf`` when the reference power
        is zero but noise is present.

    Raises:
        QualityMetricsError: If either signal is empty or the calculation
            fails. Signals of different lengths are NOT an error: both are
            truncated to the shorter length before comparison.
    """
    try:
        # Promote to float64 up front: fixed-width integer PCM (e.g. int16)
        # silently wraps around when squared or subtracted, which corrupts
        # both power estimates.
        clean = np.asarray(clean_signal, dtype=np.float64)
        noisy = np.asarray(noisy_signal, dtype=np.float64)

        # Compare only the overlapping part of the two signals.
        min_len = min(len(clean), len(noisy))
        if min_len == 0:
            # np.mean of an empty array is NaN; fail loudly instead.
            raise ValueError("signals must not be empty")
        clean = clean[:min_len]
        noisy = noisy[:min_len]

        noise = noisy - clean
        signal_power = np.mean(clean**2)
        noise_power = np.mean(noise**2)

        # No noise at all: the ratio is unbounded.
        if noise_power == 0:
            return float("inf")
        # Silent reference with noise present.
        if signal_power == 0:
            return float("-inf")

        return float(10 * np.log10(signal_power / noise_power))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate SNR: {str(e)}") from e
def calculate_snr_segmental(
    signal: np.ndarray, sample_rate: int, frame_length_ms: int = 20
) -> float:
    """
    Estimate a segmental level (in dB) for a signal without a clean reference.

    The signal is cut into frames of ``frame_length_ms`` with 50% overlap.
    For every frame with non-zero power, ``10 * log10(mean(frame**2))`` is
    computed, and the mean of those per-frame values is returned.

    NOTE(review): no noise floor is estimated here, so the value is the
    average frame power in dB rather than a true signal-to-noise ratio.

    Args:
        signal: Audio signal.
        sample_rate: Sample rate in Hz.
        frame_length_ms: Frame length in milliseconds.

    Returns:
        Mean per-frame power in dB, or ``0.0`` when no frame has non-zero
        power (silent signal, or signal shorter than one frame).

    Raises:
        QualityMetricsError: If the frame length works out to less than one
            sample or the calculation fails.
    """
    try:
        # float64 avoids wrap-around when squaring integer PCM samples.
        samples = np.asarray(signal, dtype=np.float64)

        frame_length = int(sample_rate * frame_length_ms / 1000)
        if frame_length < 1:
            raise ValueError(
                f"frame length must be at least 1 sample, got {frame_length}"
            )
        # A 1-sample frame would give a hop of 0, which range() rejects.
        hop_length = max(1, frame_length // 2)

        frame_levels_db = []
        # "+ 1" keeps the last full frame: range() excludes its stop value,
        # so without it a signal exactly one frame long yields no frames.
        for start in range(0, len(samples) - frame_length + 1, hop_length):
            frame_power = np.mean(samples[start : start + frame_length] ** 2)
            # Skip silent frames; log10(0) is undefined.
            if frame_power > 0:
                frame_levels_db.append(10 * np.log10(frame_power))

        if not frame_levels_db:
            return 0.0
        return float(np.mean(frame_levels_db))
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate segmental SNR: {str(e)}") from e
def calculate_stoi(
    clean_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, extended: bool = True
) -> float:
    """
    Compute the Short-Time Objective Intelligibility (STOI) score.

    Thin wrapper around ``pystoi.stoi``: both inputs are trimmed to the
    shorter of the two and forwarded together with ``sample_rate`` and the
    ``extended`` flag. Nominal range is 0-1 (higher = more intelligible).

    Args:
        clean_signal: Clean reference signal.
        degraded_signal: Degraded signal to evaluate.
        sample_rate: Sample rate in Hz.
        extended: Forwarded to ``pystoi.stoi`` to select extended STOI
            (default: True).

    Returns:
        The score returned by ``pystoi.stoi``.

    Raises:
        QualityMetricsError: If anything fails, including ``pystoi`` not
            being importable.
    """
    try:
        # Imported lazily so the module loads even without pystoi installed.
        from pystoi import stoi

        # Trim both inputs to the common length so they align sample-for-sample.
        common_len = min(len(clean_signal), len(degraded_signal))
        reference = clean_signal[:common_len]
        estimate = degraded_signal[:common_len]

        return stoi(reference, estimate, sample_rate, extended=extended)
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate STOI: {str(e)}")
def calculate_pesq(
    reference_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, mode: str = "wb"
) -> float:
    """
    Compute the Perceptual Evaluation of Speech Quality (PESQ) score.

    Thin wrapper around ``pesq.pesq``. Range: -0.5 to 4.5 (higher = better).
    No resampling is done here: the caller must already supply audio at the
    rate the chosen mode requires.

    Args:
        reference_signal: Reference (clean) signal.
        degraded_signal: Degraded signal to evaluate.
        sample_rate: Sample rate in Hz; 16000 for ``'wb'``, 8000 for ``'nb'``.
        mode: ``'wb'`` (wideband, 16kHz) or ``'nb'`` (narrowband, 8kHz). Any
            other value is passed to ``pesq.pesq`` unchecked.

    Returns:
        The score returned by ``pesq.pesq``.

    Raises:
        QualityMetricsError: If the sample rate does not match the mode, or
            if anything else fails (including ``pesq`` not being importable).
    """
    try:
        # Imported lazily so the module loads even without pesq installed.
        from pesq import pesq

        # Trim both inputs to the common length.
        common_len = min(len(reference_signal), len(degraded_signal))
        ref = reference_signal[:common_len]
        deg = degraded_signal[:common_len]

        # Each mode is tied to exactly one sample rate.
        wideband_rate_mismatch = mode == "wb" and sample_rate != 16000
        narrowband_rate_mismatch = mode == "nb" and sample_rate != 8000
        if wideband_rate_mismatch:
            raise QualityMetricsError(
                f"Wideband PESQ requires 16kHz sample rate, got {sample_rate}Hz. "
                "Resample before calling this function."
            )
        if narrowband_rate_mismatch:
            raise QualityMetricsError(
                f"Narrowband PESQ requires 8kHz sample rate, got {sample_rate}Hz. "
                "Resample before calling this function."
            )

        return pesq(sample_rate, ref, deg, mode)
    except QualityMetricsError:
        # Our own validation errors go out unchanged, not re-wrapped.
        raise
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ: {str(e)}")
def calculate_pesq_with_resampling(
    reference_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, mode: str = "wb"
) -> float:
    """
    Compute PESQ, resampling both signals to the rate the mode needs.

    Both inputs are trimmed to the shorter one, resampled with
    ``scipy.signal.resample`` when ``sample_rate`` differs from the target
    rate, and then scored with ``pesq.pesq``.

    Args:
        reference_signal: Reference signal.
        degraded_signal: Degraded signal.
        sample_rate: Current sample rate of both signals.
        mode: ``'wb'`` (wideband, 16kHz) or ``'nb'`` (narrowband, 8kHz). Any
            value other than ``'wb'`` selects the 8000 Hz target rate.

    Returns:
        The score returned by ``pesq.pesq``.

    Raises:
        QualityMetricsError: If anything fails, including ``pesq`` or
            ``scipy`` not being importable.
    """
    try:
        # Imported lazily so the module loads without these packages.
        from pesq import pesq
        from scipy.signal import resample

        # Trim both inputs to the common length.
        common_len = min(len(reference_signal), len(degraded_signal))
        ref = reference_signal[:common_len]
        deg = degraded_signal[:common_len]

        if mode == "wb":
            target_sr = 16000
        else:
            target_sr = 8000

        if sample_rate != target_sr:
            # Scale the sample count by the rate ratio; int() truncates.
            resampled_len = int(len(ref) * target_sr / sample_rate)
            ref = resample(ref, resampled_len)
            deg = resample(deg, resampled_len)

        return pesq(target_sr, ref, deg, mode)
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ with resampling: {str(e)}")
def validate_extraction_quality(
    original_signal: np.ndarray,
    extracted_signal: np.ndarray,
    sample_rate: int,
    snr_threshold: float = 20.0,
    stoi_threshold: float = 0.75,
    pesq_threshold: float = 2.5,
) -> dict:
    """
    Validate extraction quality against thresholds.

    Computes SNR, extended STOI and wideband PESQ (with resampling) and
    checks each against its minimum threshold. Every metric is best-effort:
    if one raises, a warning is logged, its value stays ``None`` and its
    pass flag stays ``False``.

    NOTE(review): ``original_signal`` is passed as the reference ("clean")
    argument of every metric, so the scores measure how closely the
    extraction matches the original input - confirm this is intended.

    Args:
        original_signal: Original signal, used as the reference for every
            metric.
        extracted_signal: Extracted signal to evaluate.
        sample_rate: Sample rate in Hz.
        snr_threshold: Minimum SNR in dB (default: 20).
        stoi_threshold: Minimum STOI score (default: 0.75).
        pesq_threshold: Minimum PESQ score (default: 2.5).

    Returns:
        Dictionary with keys ``snr``, ``stoi``, ``pesq`` (Python float or
        ``None``), ``snr_pass``, ``stoi_pass``, ``pesq_pass`` and
        ``overall_pass`` (Python bool). ``overall_pass`` is True only when
        all three metrics were computed AND met their thresholds; a metric
        that could not be computed counts as a failure.
    """
    results = {
        "snr": None,
        "snr_pass": False,
        "stoi": None,
        "stoi_pass": False,
        "pesq": None,
        "pesq_pass": False,
        "overall_pass": False,
    }

    # (result key, threshold, deferred computation) for each metric.
    metric_specs = (
        (
            "snr",
            snr_threshold,
            lambda: calculate_snr(original_signal, extracted_signal),
        ),
        (
            "stoi",
            stoi_threshold,
            lambda: calculate_stoi(
                original_signal, extracted_signal, sample_rate, extended=True
            ),
        ),
        (
            "pesq",
            pesq_threshold,
            lambda: calculate_pesq_with_resampling(
                original_signal, extracted_signal, sample_rate, mode="wb"
            ),
        ),
    )

    for key, threshold, compute in metric_specs:
        try:
            # Store builtin float/bool, not NumPy scalars: numpy.bool_ is not
            # a bool subclass, so it breaks json.dumps and "is True" checks.
            value = float(compute())
            results[key] = value
            results[f"{key}_pass"] = bool(value >= threshold)
        except Exception as e:
            # One failing metric must not stop the others from running.
            logger.warning("%s calculation failed: %s", key.upper(), e)

    # Strict: a metric that failed to compute keeps its False flag, so it
    # fails the overall check.
    results["overall_pass"] = (
        results["snr_pass"] and results["stoi_pass"] and results["pesq_pass"]
    )
    return results
def get_quality_label(metric_name: str, value: float) -> str:
    """
    Map a metric value to a human-readable quality label.

    Args:
        metric_name: Metric name ('snr', 'stoi', 'pesq').
        value: Metric value.

    Returns:
        The label of the first tier whose lower bound ``value`` strictly
        exceeds, the metric's lowest label when it exceeds none of them, or
        ``"Unknown"`` for an unrecognised metric name.
    """
    # Per metric: (exclusive lower bound, label) tiers ordered best-first,
    # followed by the label used when no tier matches.
    scales = {
        "snr": (
            ((40, "Excellent"), (30, "Very Good"), (20, "Good"), (10, "Fair")),
            "Poor",
        ),
        "stoi": (
            ((0.9, "Excellent"), (0.8, "Very Good"), (0.7, "Good"), (0.6, "Fair")),
            "Poor",
        ),
        "pesq": (
            ((3.5, "Excellent"), (3.0, "Good"), (2.5, "Fair"), (2.0, "Poor")),
            "Bad",
        ),
    }

    if metric_name not in scales:
        return "Unknown"

    tiers, lowest_label = scales[metric_name]
    for lower_bound, label in tiers:
        if value > lower_bound:
            return label
    return lowest_label
def generate_quality_report(metrics: dict) -> str:
    """
    Render a metrics dictionary as a plain-text quality report.

    Metrics whose value is ``None`` are left out of the report; the overall
    verdict line is always present.

    Args:
        metrics: Dictionary from validate_extraction_quality().

    Returns:
        Newline-joined report string.
    """

    def verdict_and_label(key):
        # PASS/FAIL comes from the "<key>_pass" flag, the label from the value.
        verdict = "PASS" if metrics[f"{key}_pass"] else "FAIL"
        return verdict, get_quality_label(key, metrics[key])

    lines = ["=== Voice Extraction Quality Report ===", ""]

    snr_value = metrics["snr"]
    if snr_value is not None:
        verdict, label = verdict_and_label("snr")
        lines.append(f"SNR: {snr_value:.2f} dB [{verdict}] - {label}")

    stoi_value = metrics["stoi"]
    if stoi_value is not None:
        verdict, label = verdict_and_label("stoi")
        lines.append(f"STOI: {stoi_value:.3f} [{verdict}] - {label}")

    pesq_value = metrics["pesq"]
    if pesq_value is not None:
        verdict, label = verdict_and_label("pesq")
        lines.append(f"PESQ: {pesq_value:.2f} [{verdict}] - {label}")

    overall_verdict = "PASS" if metrics["overall_pass"] else "FAIL"
    lines.extend(["", f"Overall Quality: [{overall_verdict}]"])
    return "\n".join(lines)