# TTS-with-VoiceCloning / src/mos_predictor.py
# Commit 5ffccae — "initial commit" (saadmannan)
"""
MOS (Mean Opinion Score) Predictor Module
Automated quality assessment for synthesized speech
"""
import torch
import numpy as np
import librosa
from pathlib import Path
from typing import Union, Optional
import warnings
# Suppress every Python warning for the whole process, starting at import time.
warnings.filterwarnings('ignore')

# `transformers` is optional: when it is missing, report it once and bind the
# two model class names to None so later references do not raise NameError.
try:
    from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
except ImportError:
    print("Warning: transformers not installed. Run: pip install transformers")
    Wav2Vec2Processor = Wav2Vec2ForSequenceClassification = None
class MOSPredictor:
    """
    Mean Opinion Score (MOS) prediction for speech quality assessment

    Predicts human-perceived naturalness on a 1-5 scale:
    - 5: Excellent (natural, no artifacts)
    - 4: Good (minor artifacts)
    - 3: Fair (noticeable artifacts)
    - 2: Poor (significant artifacts)
    - 1: Bad (unintelligible)

    The score is a hand-tuned heuristic computed from signal statistics
    (see ``_estimate_mos``); no learned model is loaded (``self.model`` and
    ``self.processor`` stay ``None``).
    """

    # Metric keys read by ``_estimate_mos``; each must be finite to be scored.
    _SCORED_METRICS = (
        "snr_db",
        "spectral_flatness",
        "clipping_ratio",
        "dynamic_range_db",
        "rms_energy",
    )

    def __init__(
        self,
        model_name: str = "microsoft/wavlm-base-plus",
        device: str = "cuda"
    ):
        """
        Initialize MOS Predictor

        Args:
            model_name: Pre-trained model name. Stored on the instance only;
                it is not loaded because scoring is heuristic-based.
            device: Requested device ('cuda' or 'cpu'). Falls back to 'cpu'
                whenever CUDA is not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

        print(f"📊 Initializing MOS Predictor on {self.device}...")

        # Use heuristic-based quality assessment (no model needed)
        # For production, consider NISQA or fine-tuned models
        self.processor = None
        self.model = None

        print("✓ MOS Predictor initialized!")
        print(" Using heuristic-based quality assessment")
        print(" For production, consider NISQA or fine-tuned models")

    def predict(
        self,
        audio_path: Union[str, Path],
        return_details: bool = False
    ) -> Union[float, dict]:
        """
        Predict MOS score for audio file

        The file is loaded with librosa at 16 kHz, reduced to signal metrics
        and scored heuristically.

        Args:
            audio_path: Path to audio file
            return_details: Return detailed quality metrics

        Returns:
            MOS score (1-5), or, when ``return_details`` is True, a dict with
            keys ``"mos_score"``, ``"metrics"`` and ``"quality_level"``.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            ValueError: If the decoded audio contains no samples.
            Exception: Any error raised while loading or analysing the audio
                is printed and then re-raised unchanged.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        try:
            # Load audio, resampled to 16 kHz
            audio, sr = librosa.load(str(audio_path), sr=16000)

            # With zero samples every metric is degenerate; fail loudly
            # instead of returning a meaningless score.
            if audio.size == 0:
                raise ValueError(f"Audio file contains no samples: {audio_path}")

            metrics = self._compute_quality_metrics(audio, sr)
            mos_score = self._estimate_mos(metrics)
        except Exception as e:
            print(f"❌ Error predicting MOS for {audio_path.name}: {e}")
            raise

        if return_details:
            return {
                "mos_score": mos_score,
                "metrics": metrics,
                "quality_level": self._get_quality_level(mos_score)
            }
        return mos_score

    def predict_batch(
        self,
        audio_paths: list,
        return_details: bool = False
    ) -> list:
        """
        Predict MOS scores for multiple audio files

        Scoring is best-effort: a file that fails is reported and represented
        by ``None``, so the output stays aligned with ``audio_paths``.

        Args:
            audio_paths: List of audio file paths
            return_details: Return detailed metrics

        Returns:
            List of MOS scores or detailed dicts (``None`` for failed files),
            in the same order as ``audio_paths``.
        """
        results = []
        print(f"📊 Predicting MOS for {len(audio_paths)} files...")

        for audio_path in audio_paths:
            try:
                result = self.predict(audio_path, return_details=return_details)
            except Exception as e:
                print(f"⚠️ Skipping {audio_path}: {e}")
                results.append(None)
                continue

            results.append(result)
            if not return_details:
                print(f" {Path(audio_path).name}: MOS = {result:.2f}")

        return results

    def _compute_quality_metrics(
        self,
        audio: np.ndarray,
        sr: int
    ) -> dict:
        """
        Compute audio quality metrics

        Args:
            audio: Audio samples (assumed 1-D mono, as produced by
                ``librosa.load`` in ``predict`` — confirm for other callers)
            sr: Sample rate

        Returns:
            Dict of float metrics: ``snr_db``, ``spectral_flatness``,
            ``zero_crossing_rate``, ``spectral_centroid``, ``rms_energy``,
            ``clipping_ratio`` and ``dynamic_range_db``.
        """
        metrics = {}

        # Frame-wise RMS energy (first row of librosa's output), computed once
        # and shared by the SNR estimate and the loudness metric.
        energy = librosa.feature.rms(y=audio)[0]
        abs_audio = np.abs(audio)

        # 1. Signal-to-Noise Ratio (SNR) estimation
        # Quiet frames (10th percentile) approximate the noise floor, loud
        # frames (90th percentile) the speech level.
        noise_threshold = np.percentile(energy, 10)
        signal_threshold = np.percentile(energy, 90)
        snr_estimate = 20 * np.log10((signal_threshold + 1e-8) / (noise_threshold + 1e-8))
        metrics["snr_db"] = float(snr_estimate)

        # 2. Spectral Flatness (measure of tonality vs noise)
        spectral_flatness = librosa.feature.spectral_flatness(y=audio)
        metrics["spectral_flatness"] = float(np.mean(spectral_flatness))

        # 3. Zero Crossing Rate (measure of noisiness)
        zcr = librosa.feature.zero_crossing_rate(audio)
        metrics["zero_crossing_rate"] = float(np.mean(zcr))

        # 4. Spectral Centroid (brightness)
        spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
        metrics["spectral_centroid"] = float(np.mean(spectral_centroid))

        # 5. RMS Energy (overall loudness), reusing the frame RMS from above
        metrics["rms_energy"] = float(np.mean(energy))

        # 6. Clipping detection: fraction of samples at (near) full scale
        metrics["clipping_ratio"] = float(np.mean(abs_audio > 0.99))

        # 7. "Dynamic range": peak-to-mean absolute amplitude ratio in dB
        # (a crest-factor-style measure, not a loud/quiet level spread).
        dynamic_range = 20 * np.log10((np.max(abs_audio) + 1e-8) / (np.mean(abs_audio) + 1e-8))
        metrics["dynamic_range_db"] = float(dynamic_range)

        return metrics

    def _estimate_mos(self, metrics: dict) -> float:
        """
        Estimate MOS score from quality metrics (heuristic-based)

        Starts from 5.0 and subtracts penalties for low SNR, high spectral
        flatness, clipping, low dynamic range and extreme loudness, then
        clips the result to [1.0, 5.0].

        Args:
            metrics: Quality metrics dict; must contain the keys listed in
                ``_SCORED_METRICS`` (other keys are ignored).

        Returns:
            Estimated MOS score (1-5). Returns 1.0 if any scored metric is
            non-finite.
        """
        # NaN compares False against every threshold below, which would skip
        # all penalties and report a perfect 5.0. Treat it as worst-case.
        if not all(np.isfinite(metrics[key]) for key in self._SCORED_METRICS):
            return 1.0

        score = 5.0  # Start with perfect score

        # Penalize low SNR
        if metrics["snr_db"] < 20:
            score -= (20 - metrics["snr_db"]) / 10

        # Penalize high spectral flatness (noisy)
        if metrics["spectral_flatness"] > 0.5:
            score -= (metrics["spectral_flatness"] - 0.5) * 2

        # Penalize clipping
        if metrics["clipping_ratio"] > 0.01:
            score -= metrics["clipping_ratio"] * 10

        # Penalize low dynamic range
        if metrics["dynamic_range_db"] < 10:
            score -= (10 - metrics["dynamic_range_db"]) / 5

        # Penalize very low or very high energy
        if metrics["rms_energy"] < 0.01:
            score -= 1.0
        elif metrics["rms_energy"] > 0.5:
            score -= 0.5

        # Clip to valid range
        score = np.clip(score, 1.0, 5.0)
        return float(score)

    @staticmethod
    def _get_quality_level(mos_score: float) -> str:
        """
        Get quality level description from MOS score

        Args:
            mos_score: MOS score (1-5)

        Returns:
            "Excellent" (>= 4.5), "Good" (>= 4.0), "Fair" (>= 3.0),
            "Poor" (>= 2.0) or "Bad" otherwise.
        """
        if mos_score >= 4.5:
            return "Excellent"
        elif mos_score >= 4.0:
            return "Good"
        elif mos_score >= 3.0:
            return "Fair"
        elif mos_score >= 2.0:
            return "Poor"
        else:
            return "Bad"

    def compare_quality(
        self,
        audio_path1: Union[str, Path],
        audio_path2: Union[str, Path]
    ) -> dict:
        """
        Compare quality between two audio files

        Args:
            audio_path1: First audio file
            audio_path2: Second audio file

        Returns:
            Dict with per-file ``"audio1"``/``"audio2"`` summaries (path, mos,
            quality), ``"difference"`` (mos1 - mos2), ``"better"`` (the
            higher-scoring key, ``"audio2"`` on ties) and ``"tie"`` (True
            when both scores are equal).
        """
        result1 = self.predict(audio_path1, return_details=True)
        result2 = self.predict(audio_path2, return_details=True)
        mos1 = result1["mos_score"]
        mos2 = result2["mos_score"]

        comparison = {
            "audio1": {
                "path": str(audio_path1),
                "mos": mos1,
                "quality": result1["quality_level"]
            },
            "audio2": {
                "path": str(audio_path2),
                "mos": mos2,
                "quality": result2["quality_level"]
            },
            "difference": mos1 - mos2,
            # Unchanged for backward compatibility: ties resolve to "audio2".
            "better": "audio1" if mos1 > mos2 else "audio2",
            # Scores are clipped to [1, 5], so equal scores (e.g. two clean
            # files both at 5.0) are possible; check this before "better".
            "tie": mos1 == mos2
        }

        return comparison

    def __repr__(self):
        return f"MOSPredictor(device={self.device})"
def main():
    """Print a demo banner, build a predictor and summarise its scoring scale."""
    separator = "=" * 60
    print(separator)
    print("MOS Predictor Demo")
    print(separator)

    # Building the predictor prints its own initialisation messages.
    MOSPredictor(device="cuda")

    summary_lines = (
        "\n✓ MOS Predictor ready!",
        " Score range: 1-5",
        " 5 = Excellent, 4 = Good, 3 = Fair, 2 = Poor, 1 = Bad",
        "\n Quality metrics:",
        " - SNR (Signal-to-Noise Ratio)",
        " - Spectral Flatness",
        " - Zero Crossing Rate",
        " - Dynamic Range",
        " - Clipping Detection",
    )
    for line in summary_lines:
        print(line)

    print("\n" + separator)


# Run the demo only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()