File size: 9,573 Bytes
5ffccae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""
MOS (Mean Opinion Score) Predictor Module
Automated quality assessment for synthesized speech
"""

import torch
import numpy as np
import librosa
from pathlib import Path
from typing import Union, Optional
import warnings
warnings.filterwarnings('ignore')

try:
    from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
except ImportError:
    print("Warning: transformers not installed. Run: pip install transformers")
    Wav2Vec2Processor = None
    Wav2Vec2ForSequenceClassification = None


class MOSPredictor:
    """
    Mean Opinion Score (MOS) prediction for speech quality assessment

    Predicts human-perceived naturalness on a 1-5 scale:
    - 5: Excellent (natural, no artifacts)
    - 4: Good (minor artifacts)
    - 3: Fair (noticeable artifacts)
    - 2: Poor (significant artifacts)
    - 1: Bad (unintelligible)

    The current implementation is heuristic: it derives a handful of signal
    statistics with librosa and maps them to a score via fixed penalties.
    No neural model is loaded.
    """

    def __init__(
        self,
        model_name: str = "microsoft/wavlm-base-plus",
        device: str = "cuda"
    ):
        """
        Initialize MOS Predictor

        Args:
            model_name: Pre-trained model identifier. Stored on the instance
                but not loaded; scoring is heuristic-only.
            device: Device to run on ('cuda' or 'cpu'). Falls back to 'cpu'
                when CUDA is not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

        print(f"📊 Initializing MOS Predictor on {self.device}...")

        # Use heuristic-based quality assessment (no model needed)
        # For production, consider NISQA or fine-tuned models
        self.processor = None
        self.model = None

        print("✓ MOS Predictor initialized!")
        print("   Using heuristic-based quality assessment")
        print("   For production, consider NISQA or fine-tuned models")

    def predict(
        self,
        audio_path: Union[str, Path],
        return_details: bool = False
    ) -> Union[float, dict]:
        """
        Predict MOS score for audio file

        The file is resampled to 16 kHz before analysis.

        Args:
            audio_path: Path to audio file
            return_details: Return detailed quality metrics

        Returns:
            MOS score (1-5), or a dict with keys ``mos_score``, ``metrics``
            and ``quality_level`` when ``return_details`` is True.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            Exception: Any error raised while loading or analysing the audio
                is printed and then re-raised unchanged.
        """
        audio_path = Path(audio_path)

        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        try:
            # Load audio
            audio, sr = librosa.load(str(audio_path), sr=16000)

            # Compute quality metrics
            metrics = self._compute_quality_metrics(audio, sr)

            # Estimate MOS score (heuristic-based)
            mos_score = self._estimate_mos(metrics)

            if return_details:
                return {
                    "mos_score": mos_score,
                    "metrics": metrics,
                    "quality_level": self._get_quality_level(mos_score)
                }
            else:
                return mos_score

        except Exception as e:
            print(f"❌ Error predicting MOS for {audio_path.name}: {e}")
            raise

    def predict_batch(
        self,
        audio_paths: list,
        return_details: bool = False
    ) -> list:
        """
        Predict MOS scores for multiple audio files

        Failures are reported and recorded as ``None`` so that one bad file
        does not abort the whole batch.

        Args:
            audio_paths: List of audio file paths
            return_details: Return detailed metrics

        Returns:
            List of MOS scores or detailed dicts, aligned with ``audio_paths``
            (``None`` for files that failed).
        """
        results = []

        print(f"📊 Predicting MOS for {len(audio_paths)} files...")

        for audio_path in audio_paths:
            try:
                result = self.predict(audio_path, return_details=return_details)
                results.append(result)

                if not return_details:
                    print(f"   {Path(audio_path).name}: MOS = {result:.2f}")

            except Exception as e:
                print(f"⚠️  Skipping {audio_path}: {e}")
                results.append(None)

        return results

    def _compute_quality_metrics(
        self,
        audio: np.ndarray,
        sr: int
    ) -> dict:
        """
        Compute audio quality metrics

        Args:
            audio: Audio samples (as returned by ``librosa.load``)
            sr: Sample rate in Hz

        Returns:
            Dict with keys ``snr_db``, ``spectral_flatness``,
            ``zero_crossing_rate``, ``spectral_centroid``, ``rms_energy``,
            ``clipping_ratio`` and ``dynamic_range_db``.

        Raises:
            ValueError: If ``audio`` contains no samples.
        """
        audio = np.asarray(audio)
        if audio.size == 0:
            # Without this guard the clipping ratio divides by zero and
            # np.max() fails with an unhelpful "zero-size array" message.
            raise ValueError("Cannot compute quality metrics for empty audio")

        metrics = {}

        # Frame-wise RMS feeds both the SNR estimate and the loudness metric,
        # so compute it once.
        rms = librosa.feature.rms(y=audio)
        energy = rms[0]

        # 1. Signal-to-Noise Ratio (SNR) estimation
        # Treat the 10th-percentile frame RMS as the noise floor and the
        # 90th-percentile frame RMS as the signal level (amplitude ratio -> 20*log10).
        noise_threshold = np.percentile(energy, 10)
        signal_threshold = np.percentile(energy, 90)
        snr_estimate = 20 * np.log10((signal_threshold + 1e-8) / (noise_threshold + 1e-8))
        metrics["snr_db"] = float(snr_estimate)

        # 2. Spectral Flatness (measure of tonality vs noise)
        spectral_flatness = librosa.feature.spectral_flatness(y=audio)
        metrics["spectral_flatness"] = float(np.mean(spectral_flatness))

        # 3. Zero Crossing Rate (measure of noisiness)
        zcr = librosa.feature.zero_crossing_rate(audio)
        metrics["zero_crossing_rate"] = float(np.mean(zcr))

        # 4. Spectral Centroid (brightness)
        spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
        metrics["spectral_centroid"] = float(np.mean(spectral_centroid))

        # 5. RMS Energy (overall loudness)
        metrics["rms_energy"] = float(np.mean(rms))

        abs_audio = np.abs(audio)

        # 6. Clipping detection: fraction of samples at (near) full scale
        clipping_ratio = np.count_nonzero(abs_audio > 0.99) / abs_audio.size
        metrics["clipping_ratio"] = float(clipping_ratio)

        # 7. Dynamic range: peak vs. mean absolute amplitude, in dB
        dynamic_range = 20 * np.log10((np.max(abs_audio) + 1e-8) / (np.mean(abs_audio) + 1e-8))
        metrics["dynamic_range_db"] = float(dynamic_range)

        return metrics

    def _estimate_mos(self, metrics: dict) -> float:
        """
        Estimate MOS score from quality metrics (heuristic-based)

        Starts from 5.0 and subtracts fixed penalties; the result is clipped
        to [1, 5].

        Args:
            metrics: Quality metrics dict (see ``_compute_quality_metrics``)

        Returns:
            Estimated MOS score (1-5)
        """
        score = 5.0  # Start with perfect score

        # Penalize low SNR
        if metrics["snr_db"] < 20:
            score -= (20 - metrics["snr_db"]) / 10

        # Penalize high spectral flatness (noisy)
        if metrics["spectral_flatness"] > 0.5:
            score -= (metrics["spectral_flatness"] - 0.5) * 2

        # Penalize clipping
        if metrics["clipping_ratio"] > 0.01:
            score -= metrics["clipping_ratio"] * 10

        # Penalize low dynamic range
        if metrics["dynamic_range_db"] < 10:
            score -= (10 - metrics["dynamic_range_db"]) / 5

        # Penalize very low or very high energy
        if metrics["rms_energy"] < 0.01:
            score -= 1.0
        elif metrics["rms_energy"] > 0.5:
            score -= 0.5

        # Clip to valid range
        score = np.clip(score, 1.0, 5.0)

        return float(score)

    @staticmethod
    def _get_quality_level(mos_score: float) -> str:
        """
        Get quality level description from MOS score

        Args:
            mos_score: MOS score (1-5)

        Returns:
            Quality level string
        """
        if mos_score >= 4.5:
            return "Excellent"
        elif mos_score >= 4.0:
            return "Good"
        elif mos_score >= 3.0:
            return "Fair"
        elif mos_score >= 2.0:
            return "Poor"
        else:
            return "Bad"

    def compare_quality(
        self,
        audio_path1: Union[str, Path],
        audio_path2: Union[str, Path]
    ) -> dict:
        """
        Compare quality between two audio files

        Args:
            audio_path1: First audio file
            audio_path2: Second audio file

        Returns:
            Dict with per-file results, ``difference`` (mos1 - mos2),
            ``better`` ("audio1" or "audio2") and ``tie`` (True when both
            scores are equal; ``better`` is then "audio2" for backward
            compatibility and should not be read as a real preference).
        """
        result1 = self.predict(audio_path1, return_details=True)
        result2 = self.predict(audio_path2, return_details=True)

        mos1 = result1["mos_score"]
        mos2 = result2["mos_score"]

        comparison = {
            "audio1": {
                "path": str(audio_path1),
                "mos": mos1,
                "quality": result1["quality_level"]
            },
            "audio2": {
                "path": str(audio_path2),
                "mos": mos2,
                "quality": result2["quality_level"]
            },
            "difference": mos1 - mos2,
            "better": "audio1" if mos1 > mos2 else "audio2",
            # Scores are clipped to [1, 5], so exact ties (e.g. both 5.0) are
            # common; expose them instead of silently favouring audio2.
            "tie": mos1 == mos2
        }

        return comparison

    def __repr__(self):
        return f"MOSPredictor(device={self.device})"


def main():
    """Demo usage of MOSPredictor: build one and print what it scores on."""
    print("=" * 60)
    print("MOS Predictor Demo")
    print("=" * 60)

    # Initialize (construction prints setup info and falls back to CPU
    # when CUDA is unavailable)
    predictor = MOSPredictor(device="cuda")

    print("\n✓ MOS Predictor ready!")
    print(f"   {predictor!r}")
    print("   Score range: 1-5")
    print("   5 = Excellent, 4 = Good, 3 = Fair, 2 = Poor, 1 = Bad")
    print("\n   Quality metrics:")
    # Keep in sync with MOSPredictor._compute_quality_metrics.
    for metric_label in (
        "SNR (Signal-to-Noise Ratio)",
        "Spectral Flatness",
        "Zero Crossing Rate",
        "Spectral Centroid",
        "RMS Energy",
        "Dynamic Range",
        "Clipping Detection",
    ):
        print(f"   - {metric_label}")

    print("\n" + "=" * 60)


if __name__ == "__main__":
    main()