Spaces:
Runtime error
Runtime error
| """ | |
| MOS (Mean Opinion Score) Predictor Module | |
| Automated quality assessment for synthesized speech | |
| """ | |
| import torch | |
| import numpy as np | |
| import librosa | |
| from pathlib import Path | |
| from typing import Union, Optional | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| try: | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification | |
| except ImportError: | |
| print("Warning: transformers not installed. Run: pip install transformers") | |
| Wav2Vec2Processor = None | |
| Wav2Vec2ForSequenceClassification = None | |
class MOSPredictor:
    """
    Mean Opinion Score (MOS) prediction for speech quality assessment.

    Predicts human-perceived naturalness on a 1-5 scale:
        - 5: Excellent (natural, no artifacts)
        - 4: Good (minor artifacts)
        - 3: Fair (noticeable artifacts)
        - 2: Poor (significant artifacts)
        - 1: Bad (unintelligible)

    This implementation is heuristic-based: it derives a score from signal
    statistics (SNR estimate, spectral flatness, clipping, dynamic range,
    RMS energy) rather than a learned model.
    """

    def __init__(
        self,
        model_name: str = "microsoft/wavlm-base-plus",
        device: str = "cuda"
    ):
        """
        Initialize MOS Predictor.

        Args:
            model_name: Pre-trained model identifier. Currently unused by the
                heuristic scorer but kept for interface compatibility with a
                future model-backed implementation.
            device: Device to run on ('cuda' or 'cpu'). Falls back to 'cpu'
                when CUDA is unavailable.
        """
        # Honor the requested device only if CUDA is actually present.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        print(f"📊 Initializing MOS Predictor on {self.device}...")
        # Use heuristic-based quality assessment (no model needed)
        # For production, consider NISQA or fine-tuned models
        self.processor = None
        self.model = None
        print("✓ MOS Predictor initialized!")
        print(" Using heuristic-based quality assessment")
        print(" For production, consider NISQA or fine-tuned models")

    def predict(
        self,
        audio_path: Union[str, Path],
        return_details: bool = False
    ) -> Union[float, dict]:
        """
        Predict MOS score for an audio file.

        Args:
            audio_path: Path to audio file.
            return_details: If True, return a dict with the score, the raw
                quality metrics, and a human-readable quality level.

        Returns:
            MOS score (1-5) or dict with detailed metrics.

        Raises:
            FileNotFoundError: If the audio file does not exist.
            Exception: Re-raises any error from audio loading/analysis
                after logging it.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")
        try:
            # Load audio resampled to 16 kHz (common rate for speech models).
            audio, sr = librosa.load(str(audio_path), sr=16000)
            # Compute quality metrics
            metrics = self._compute_quality_metrics(audio, sr)
            # Estimate MOS score (heuristic-based)
            mos_score = self._estimate_mos(metrics)
            if return_details:
                return {
                    "mos_score": mos_score,
                    "metrics": metrics,
                    "quality_level": self._get_quality_level(mos_score)
                }
            else:
                return mos_score
        except Exception as e:
            print(f"❌ Error predicting MOS for {audio_path.name}: {e}")
            raise

    def predict_batch(
        self,
        audio_paths: list,
        return_details: bool = False
    ) -> list:
        """
        Predict MOS scores for multiple audio files.

        Files that fail to load/analyze are skipped (a warning is printed)
        and contribute ``None`` to the result list, so the output always has
        the same length as the input.

        Args:
            audio_paths: List of audio file paths.
            return_details: Return detailed metric dicts instead of floats.

        Returns:
            List of MOS scores (or detail dicts), with ``None`` for failures.
        """
        results = []
        print(f"📊 Predicting MOS for {len(audio_paths)} files...")
        for audio_path in audio_paths:
            try:
                result = self.predict(audio_path, return_details=return_details)
                results.append(result)
                if not return_details:
                    print(f" {Path(audio_path).name}: MOS = {result:.2f}")
            except Exception as e:
                print(f"⚠️ Skipping {audio_path}: {e}")
                results.append(None)
        return results

    def _compute_quality_metrics(
        self,
        audio: np.ndarray,
        sr: int
    ) -> dict:
        """
        Compute audio quality metrics from a mono waveform.

        Args:
            audio: Audio array (1-D float waveform, as returned by librosa).
            sr: Sample rate in Hz.

        Returns:
            Dict of quality metrics (all values plain Python floats).
        """
        metrics = {}
        # 1. Signal-to-Noise Ratio (SNR) estimation.
        # Proxy: ratio of the 90th-percentile frame energy (signal) to the
        # 10th-percentile frame energy (noise floor). The 1e-8 epsilon
        # guards against log(0) on silent audio.
        energy = librosa.feature.rms(y=audio)[0]
        noise_threshold = np.percentile(energy, 10)
        signal_threshold = np.percentile(energy, 90)
        snr_estimate = 20 * np.log10((signal_threshold + 1e-8) / (noise_threshold + 1e-8))
        metrics["snr_db"] = float(snr_estimate)
        # 2. Spectral Flatness (measure of tonality vs noise; ~1 = noise-like)
        spectral_flatness = librosa.feature.spectral_flatness(y=audio)
        metrics["spectral_flatness"] = float(np.mean(spectral_flatness))
        # 3. Zero Crossing Rate (measure of noisiness)
        zcr = librosa.feature.zero_crossing_rate(audio)
        metrics["zero_crossing_rate"] = float(np.mean(zcr))
        # 4. Spectral Centroid (brightness)
        spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
        metrics["spectral_centroid"] = float(np.mean(spectral_centroid))
        # 5. RMS Energy (overall loudness)
        rms = librosa.feature.rms(y=audio)
        metrics["rms_energy"] = float(np.mean(rms))
        # 6. Clipping detection: fraction of samples at/near full scale.
        clipping_ratio = np.sum(np.abs(audio) > 0.99) / len(audio)
        metrics["clipping_ratio"] = float(clipping_ratio)
        # 7. Dynamic range: peak-to-mean level in dB.
        dynamic_range = 20 * np.log10((np.max(np.abs(audio)) + 1e-8) / (np.mean(np.abs(audio)) + 1e-8))
        metrics["dynamic_range_db"] = float(dynamic_range)
        return metrics

    def _estimate_mos(self, metrics: dict) -> float:
        """
        Estimate MOS score from quality metrics (heuristic-based).

        Starts from a perfect 5.0 and subtracts penalties for each
        degradation indicator, then clips to the valid [1, 5] range.

        Args:
            metrics: Quality metrics dict as produced by
                :meth:`_compute_quality_metrics`.

        Returns:
            Estimated MOS score (1-5).
        """
        score = 5.0  # Start with perfect score
        # Penalize low SNR (up to -2.0 at 0 dB)
        if metrics["snr_db"] < 20:
            score -= (20 - metrics["snr_db"]) / 10
        # Penalize high spectral flatness (noisy)
        if metrics["spectral_flatness"] > 0.5:
            score -= (metrics["spectral_flatness"] - 0.5) * 2
        # Penalize clipping (only beyond a 1% tolerance)
        if metrics["clipping_ratio"] > 0.01:
            score -= metrics["clipping_ratio"] * 10
        # Penalize low dynamic range (over-compressed audio)
        if metrics["dynamic_range_db"] < 10:
            score -= (10 - metrics["dynamic_range_db"]) / 5
        # Penalize very low (near-silent) or very high energy
        if metrics["rms_energy"] < 0.01:
            score -= 1.0
        elif metrics["rms_energy"] > 0.5:
            score -= 0.5
        # Clip to valid range
        score = np.clip(score, 1.0, 5.0)
        return float(score)

    @staticmethod
    def _get_quality_level(mos_score: float) -> str:
        """
        Get quality level description from MOS score.

        BUGFIX: originally defined as an instance method without ``self``,
        so calling ``self._get_quality_level(score)`` raised a TypeError
        (two positional arguments to a one-argument function). Declaring it
        a @staticmethod fixes the call while keeping the same signature.

        Args:
            mos_score: MOS score (1-5).

        Returns:
            Quality level string.
        """
        if mos_score >= 4.5:
            return "Excellent"
        elif mos_score >= 4.0:
            return "Good"
        elif mos_score >= 3.0:
            return "Fair"
        elif mos_score >= 2.0:
            return "Poor"
        else:
            return "Bad"

    def compare_quality(
        self,
        audio_path1: Union[str, Path],
        audio_path2: Union[str, Path]
    ) -> dict:
        """
        Compare quality between two audio files.

        Args:
            audio_path1: First audio file.
            audio_path2: Second audio file.

        Returns:
            Dict with per-file MOS/quality, the score difference
            (audio1 - audio2), and which file scored better.
            NOTE(review): on an exact tie "better" reports "audio2".
        """
        result1 = self.predict(audio_path1, return_details=True)
        result2 = self.predict(audio_path2, return_details=True)
        comparison = {
            "audio1": {
                "path": str(audio_path1),
                "mos": result1["mos_score"],
                "quality": result1["quality_level"]
            },
            "audio2": {
                "path": str(audio_path2),
                "mos": result2["mos_score"],
                "quality": result2["quality_level"]
            },
            "difference": result1["mos_score"] - result2["mos_score"],
            "better": "audio1" if result1["mos_score"] > result2["mos_score"] else "audio2"
        }
        return comparison

    def __repr__(self):
        return f"MOSPredictor(device={self.device})"
def main():
    """Demo usage of MOSPredictor"""
    banner = "=" * 60
    print(banner)
    print("MOS Predictor Demo")
    print(banner)
    # Initialize (falls back to CPU internally when CUDA is unavailable)
    predictor = MOSPredictor(device="cuda")
    print("\n✓ MOS Predictor ready!")
    print(" Score range: 1-5")
    print(" 5 = Excellent, 4 = Good, 3 = Fair, 2 = Poor, 1 = Bad")
    print("\n Quality metrics:")
    metric_labels = (
        "SNR (Signal-to-Noise Ratio)",
        "Spectral Flatness",
        "Zero Crossing Rate",
        "Dynamic Range",
        "Clipping Detection",
    )
    for label in metric_labels:
        print(f" - {label}")
    print("\n" + banner)
| if __name__ == "__main__": | |
| main() | |