|
|
""" |
|
|
Audio Evaluation Script

This script evaluates the quality of generated audio against ground truth audio
using a variety of metrics, including:
- SI-SDR (Scale-Invariant Signal-to-Distortion Ratio)
- Multi-Resolution STFT Loss
- Multi-Resolution Mel-Spectrogram Loss
- Phase Coherence (Per-channel and Inter-channel)
- Loudness metrics (LUFS-I, LRA, True Peak) via ffmpeg.

The script processes a directory of models, where each model directory contains
pairs of reconstructed (*_vae_rec.wav) and ground truth (*.wav) audio files.
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
import json |
|
|
import logging |
|
|
import argparse |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio |
|
|
import auraloss |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# --- Logging -----------------------------------------------------------------
# INFO-and-above records are written to stdout with a timestamp and level.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- Global settings ---------------------------------------------------------
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Sample rate (Hz) passed to the spectral losses below.
SAMPLE_RATE = 44100

# --- SI-SDR ------------------------------------------------------------------
sisdr_criteria = auraloss.time.SISDRLoss()
sisdr_criteria = sisdr_criteria.to(DEVICE)

# --- Multi-resolution mel loss -----------------------------------------------
# Four resolutions; the window length equals the FFT size (the same list object
# is reused on purpose) and the hop is a quarter of the FFT size.
mel_fft_sizes = [4096, 2048, 1024, 512]
mel_win_sizes = mel_fft_sizes
mel_hop_sizes = [size // 4 for size in mel_fft_sizes]
mel_criteria = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=mel_fft_sizes,
    hop_sizes=mel_hop_sizes,
    win_lengths=mel_win_sizes,
    sample_rate=SAMPLE_RATE,
    scale="mel",
    n_bins=64,
    perceptual_weighting=True,
)
mel_criteria = mel_criteria.to(DEVICE)

# --- Multi-resolution STFT loss ----------------------------------------------
# Six resolutions laid out the same way (window == FFT size, hop == FFT size / 4),
# without the mel scaling used above.
fft_sizes = [4096, 2048, 1024, 512, 256, 128]
win_sizes = fft_sizes
hop_sizes = [size // 4 for size in fft_sizes]
stft_criteria = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=fft_sizes,
    hop_sizes=hop_sizes,
    win_lengths=win_sizes,
    sample_rate=SAMPLE_RATE,
    perceptual_weighting=True,
)
stft_criteria = stft_criteria.to(DEVICE)
|
|
|
|
|
|
|
|
def _parse_ebur128_summary(output_text: str) -> Dict[str, float]:
    """Extract LUFS-I, LRA and True Peak values from ffmpeg ebur128 log text.

    Only values that are found are included, so the result may be empty.
    """
    # Restrict the search to the last "Summary:" section when one is present so
    # that nothing printed earlier in the log can be mistaken for a final value.
    summary_start = output_text.rfind("Summary:")
    if summary_start != -1:
        output_text = output_text[summary_start:]

    # The numeric groups only accept well-formed decimals, so float() below
    # cannot raise on a stray "." or "1.2.3".
    patterns = (
        ("LUFS-I", r"^\s*I:\s*(-?\d+(?:\.\d+)?)\s*LUFS"),
        ("LRA", r"^\s*LRA:\s*(\d+(?:\.\d+)?)\s*LU"),
        ("True Peak", r"Peak:\s*(-?\d+(?:\.\d+)?)\s*dBFS"),
    )

    loudness_data = {}
    for key, pattern in patterns:
        match = re.search(pattern, output_text, re.MULTILINE)
        if match:
            loudness_data[key] = float(match.group(1))
    return loudness_data


def analyze_loudness(file_path: str) -> Optional[Dict[str, float]]:
    """
    Analyzes audio file loudness using ffmpeg's ebur128 filter.

    ffmpeg is run as a subprocess and the values are parsed from the text it
    writes to stderr.

    Args:
        file_path: Path to the audio file.

    Returns:
        A dictionary with whichever of 'LUFS-I', 'LRA' and 'True Peak' could be
        parsed, or None when the file is missing, ffmpeg is unavailable or
        fails, or no value could be parsed.
    """
    if not Path(file_path).exists():
        logging.warning(f"Loudness analysis skipped: File not found at {file_path}")
        return None

    command = [
        "ffmpeg",
        "-nostats",
        "-i", file_path,
        "-af", "ebur128=peak=true,ametadata=mode=print:file=-",
        "-f", "null",
        "-",
    ]

    try:
        # errors='replace' keeps undecodable bytes in ffmpeg's output (e.g.
        # legacy-encoded metadata tags echoed in the log) from raising an
        # uncaught UnicodeDecodeError; the numeric summary is plain ASCII.
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True,
            encoding='utf-8',
            errors='replace',
        )
    except FileNotFoundError:
        logging.error("ffmpeg not found. Please install ffmpeg and ensure it's in your PATH.")
        return None
    except subprocess.CalledProcessError as e:
        logging.error(f"ffmpeg analysis failed for {file_path}. Error: {e.stderr}")
        return None

    loudness_data = _parse_ebur128_summary(result.stderr)
    if not loudness_data:
        logging.warning(f"Could not parse loudness data for {file_path}.")
        return None

    return loudness_data
|
|
|
|
|
|
|
|
class PhaseCoherenceLoss(nn.Module):
    """
    Calculates phase coherence between two audio signals at one STFT resolution.

    For every STFT frame the phase difference between prediction and target is
    turned into a unit phasor, and the phasors are averaged across frequency
    bins with magnitude-product weights. The per-frame values are then averaged
    over time, weighted by each frame's total weight. When the prediction has
    at least two channels, the same measure is applied to the difference
    between the predicted and target channel-0 / channel-1 relative phase.

    Args:
        fft_size: FFT size passed to torch.stft (coerced to int).
        hop_size: Hop length passed to torch.stft (coerced to int).
        win_size: Hann window length (coerced to int).
        mag_threshold: Bins whose weight is not above this value are ignored.
        eps: Lower clamp applied to denominators to avoid division by zero.
    """

    def __init__(self, fft_size, hop_size, win_size, mag_threshold=1e-6, eps=1e-8):
        super().__init__()
        self.fft_size = int(fft_size)
        self.hop_size = int(hop_size)
        self.win_size = int(win_size)
        # Build the window from the int-coerced size so that float or numpy
        # scalar sizes work here exactly as they do for the other attributes.
        self.register_buffer("window", torch.hann_window(self.win_size))
        self.mag_threshold = float(mag_threshold)
        self.eps = float(eps)

    def _to_complex(self, x):
        """Return x as a complex tensor, accepting a trailing (real, imag) axis."""
        if torch.is_complex(x):
            return x
        if x.dim() >= 1 and x.size(-1) == 2:
            return torch.complex(x[..., 0], x[..., 1])
        raise ValueError("Input must be complex or real/imag tensor.")

    def _stereo_stft(self, x):
        """Compute the STFT of a (B, C, T) or (C, T) signal as (B, C, F, frames)."""
        if x.dim() == 2:
            x = x.unsqueeze(0)
        B, C, T = x.shape
        # torch.stft takes a batch of 1-D signals, so fold channels into the batch.
        stft = torch.stft(
            x.reshape(B * C, T),
            n_fft=self.fft_size,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.window,
            return_complex=True,
        )
        return stft.view(B, C, -1, stft.size(-1))

    def _weighted_coherence(self, weights, delta):
        """Reduce (..., F, frames) weights and phase differences to (...,).

        Frequency is reduced first (weighted phasor mean per frame), then time
        (per-frame coherence averaged with the frame's summed weight).
        """
        mask = (weights > self.mag_threshold).to(weights.dtype)
        weights_masked = weights * mask
        phasor = torch.complex(torch.cos(delta), torch.sin(delta))

        num = torch.sum(weights_masked * phasor, dim=-2)
        frame_energy = torch.sum(weights_masked, dim=-2)
        coherence_per_frame = torch.abs(num) / frame_energy.clamp_min(self.eps)

        energy_sum = torch.sum(frame_energy, dim=-1).clamp_min(self.eps)
        return torch.sum(coherence_per_frame * frame_energy, dim=-1) / energy_sum

    def forward(self, pred, target):
        """
        Args:
            pred: Predicted audio, shape (B, C, T) or (C, T).
            target: Reference audio with the same layout.

        Returns:
            A dict with 'per_channel_coherence' (CPU tensor with one value per
            channel, averaged over the batch) and 'interchannel_coherence'
            (CPU scalar tensor, or None when pred has fewer than two channels).
        """
        pred_stft = self._to_complex(self._stereo_stft(pred))
        target_stft = self._to_complex(self._stereo_stft(target))

        # Per-channel coherence: weights are the product of both magnitudes.
        weights = torch.abs(pred_stft) * torch.abs(target_stft)
        delta = torch.angle(pred_stft) - torch.angle(target_stft)
        coherence_chan = self._weighted_coherence(weights, delta)  # (B, C)
        per_channel_coherence = coherence_chan.mean(dim=0)

        inter_coherence = None
        if pred_stft.size(1) >= 2:
            Lp, Rp = pred_stft[:, 0], pred_stft[:, 1]
            Lt, Rt = target_stft[:, 0], target_stft[:, 1]

            # Difference between predicted and target relative phase of the
            # first two channels; weights use the predicted magnitudes only.
            inter_delta = torch.angle(Lp * torch.conj(Rp)) - torch.angle(Lt * torch.conj(Rt))
            inter_weights = torch.abs(Lp) * torch.abs(Rp)
            inter_coh_b = self._weighted_coherence(inter_weights, inter_delta)  # (B,)
            inter_coherence = inter_coh_b.mean()

        return {
            "per_channel_coherence": per_channel_coherence.detach().cpu(),
            "interchannel_coherence": (
                inter_coherence.detach().cpu() if inter_coherence is not None else None
            ),
        }
|
|
|
|
|
|
|
|
class MultiResolutionPhaseCoherenceLoss(nn.Module):
    """Averages PhaseCoherenceLoss results over several STFT resolutions."""

    def __init__(self, fft_sizes, hop_sizes, win_sizes):
        """Create one PhaseCoherenceLoss per (fft, hop, win) triple."""
        super().__init__()
        per_resolution = []
        for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_sizes):
            per_resolution.append(PhaseCoherenceLoss(n_fft, hop, win))
        self.criteria = nn.ModuleList(per_resolution)

    def forward(self, pred, target):
        """Run every resolution on (pred, target) and average the outputs.

        Returns:
            A dict with 'per_channel_coherence' (mean over resolutions) and
            'interchannel_coherence' (mean over the resolutions that reported
            one, or None when none of them did).
        """
        per_channel_items = []
        inter_items = []
        for criterion in self.criteria:
            outcome = criterion(pred, target)
            per_channel_items.append(outcome["per_channel_coherence"])
            inter_value = outcome["interchannel_coherence"]
            if inter_value is not None:
                inter_items.append(inter_value)

        per_channel = torch.stack(per_channel_items).mean(dim=0)
        if inter_items:
            inter_channel = torch.stack(inter_items).mean()
        else:
            inter_channel = None

        return {"per_channel_coherence": per_channel, "interchannel_coherence": inter_channel}
|
|
|
|
|
# Phase-coherence metric shared by all evaluations; it is built with the same
# FFT / hop / window sizes as the mel loss and moved to the evaluation device.
phase_coherence_criteria = MultiResolutionPhaseCoherenceLoss(
    fft_sizes=mel_fft_sizes,
    hop_sizes=mel_hop_sizes,
    win_sizes=mel_win_sizes,
)
phase_coherence_criteria = phase_coherence_criteria.to(DEVICE)
|
|
|
|
|
|
|
|
def find_audio_pairs(model_path: Path) -> List[Tuple[Path, Path]]:
    """Finds pairs of reconstructed and ground truth audio files.

    A reconstructed file is any entry of ``model_path`` whose name ends in
    ``_vae_rec.wav``; its ground truth is the file in the same directory with
    that suffix replaced by ``.wav``.

    Args:
        model_path: Directory to scan (not searched recursively).

    Returns:
        (reconstructed, ground_truth) path tuples, ordered by the sorted
        reconstructed paths. Reconstructed files without a ground truth are
        skipped with a warning.
    """
    rec_suffix = "_vae_rec.wav"
    pairs = []
    for rec_file in sorted(model_path.glob(f"*{rec_suffix}")):
        # Strip only the trailing suffix. str.replace() would also rewrite any
        # earlier occurrence of the suffix inside the file name.
        stem = rec_file.name[: -len(rec_suffix)]
        gt_file = model_path / (stem + ".wav")
        if gt_file.exists():
            pairs.append((rec_file, gt_file))
        else:
            logging.warning(f"Ground truth file not found for {rec_file.name}")
    return pairs
|
|
|
|
|
|
|
|
def evaluate_pair(rec_path: Path, gt_path: Path) -> Optional[Dict[str, float]]:
    """Evaluates a single pair of audio files.

    Both files are loaded, resampled to SAMPLE_RATE when needed, trimmed to the
    shorter length, and scored with the module-level criteria.

    Args:
        rec_path: Path of the reconstructed audio file.
        gt_path: Path of the ground truth audio file.

    Returns:
        A dict with 'sisdr' (the negated SI-SDR loss), 'mel_distance',
        'stft_distance', 'per_channel_coherence' and, when the phase metric
        reports one, 'interchannel_coherence'. None if any step raises; the
        error is logged.
    """

    def _at_target_rate(waveform, source_sr):
        # Resample only when the file's rate differs from SAMPLE_RATE.
        if source_sr == SAMPLE_RATE:
            return waveform
        return torchaudio.transforms.Resample(source_sr, SAMPLE_RATE)(waveform)

    try:
        rec_audio, rec_sr = torchaudio.load(rec_path, backend="ffmpeg")
        ref_audio, ref_sr = torchaudio.load(gt_path, backend="ffmpeg")

        rec_audio = _at_target_rate(rec_audio, rec_sr)
        ref_audio = _at_target_rate(ref_audio, ref_sr)

        # The metrics compare the signals sample by sample, so cut both to the
        # shorter of the two lengths.
        rec_len, ref_len = rec_audio.shape[-1], ref_audio.shape[-1]
        if rec_len != ref_len:
            logging.info(f"Shape Mismatched, Trimming audio files to the same length: {rec_path.name}, {gt_path.name}")
            common_len = min(rec_len, ref_len)
            rec_audio = rec_audio[:, :common_len]
            ref_audio = ref_audio[:, :common_len]

        # Add a leading batch axis and move to the evaluation device.
        rec_batch = rec_audio.to(DEVICE).unsqueeze(0)
        ref_batch = ref_audio.to(DEVICE).unsqueeze(0)

        metrics = {
            'sisdr': -sisdr_criteria(rec_batch, ref_batch).item(),
            'mel_distance': mel_criteria(rec_batch, ref_batch).item(),
            'stft_distance': stft_criteria(rec_batch, ref_batch).item(),
        }

        phase = phase_coherence_criteria(rec_batch, ref_batch)
        metrics['per_channel_coherence'] = phase["per_channel_coherence"].mean().item()
        inter = phase["interchannel_coherence"]
        if inter is not None:
            metrics['interchannel_coherence'] = inter.item()

        return metrics
    except Exception as e:
        logging.error(f"Error processing pair {rec_path.name}, {gt_path.name}: {e}")
        return None
|
|
|
|
|
|
|
|
def _mean_by_key(records: List[Dict[str, float]]) -> Dict[str, float]:
    """Average every key that occurs in any record, in first-seen key order."""
    # Use the union of keys rather than the first record's keys: a metric that
    # is missing from the first file (e.g. 'interchannel_coherence' for a mono
    # file) must still be summarised when later files provide it.
    keys = dict.fromkeys(key for record in records for key in record)
    return {
        key: float(np.mean([record[key] for record in records if key in record]))
        for key in keys
    }


def process_model(model_path: Path, force_eval: bool = False, echo=True):
    """Processes all audio pairs for a given model.

    Evaluates every (reconstructed, ground truth) pair found in ``model_path``,
    averages the per-file metrics and loudness values, and writes the summary
    to ``evaluation_results.json`` and ``evaluation_summary.txt`` inside
    ``model_path``.

    Args:
        model_path: Directory holding one model's audio files.
        force_eval: Re-evaluate even if ``evaluation_results.json`` exists.
        echo: Log the metrics and loudness values of every evaluated pair.

    Returns:
        None. Nothing is written when results already exist (and
        ``force_eval`` is False), when no pairs are found, or when no pair
        could be evaluated.
    """
    logging.info(f"Processing model: {model_path.name}")
    results_file = model_path / "evaluation_results.json"

    if results_file.exists() and not force_eval:
        logging.info(f"Results already exist for {model_path.name}, skipping.")
        return

    audio_pairs = find_audio_pairs(model_path)
    if not audio_pairs:
        logging.warning(f"No valid audio pairs found for {model_path.name}.")
        return

    all_metrics = []
    gen_loudness_data, gt_loudness_data = [], []

    with torch.no_grad():
        for rec_path, gt_path in tqdm(audio_pairs, desc=f"Evaluating {model_path.name}"):
            pair_metrics = evaluate_pair(rec_path, gt_path)
            if not pair_metrics:
                # Pairs that failed to evaluate contribute nothing to the summary.
                continue
            all_metrics.append(pair_metrics)

            gen_loudness = analyze_loudness(str(rec_path))
            if gen_loudness:
                gen_loudness_data.append(gen_loudness)

            gt_loudness = analyze_loudness(str(gt_path))
            if gt_loudness:
                gt_loudness_data.append(gt_loudness)

            if echo:
                logging.info(f"Metrics for {rec_path.name}: {pair_metrics}")
                if gen_loudness:
                    logging.info(f"Generated Loudness: {gen_loudness}")
                if gt_loudness:
                    logging.info(f"Ground Truth Loudness: {gt_loudness}")

    if not all_metrics:
        logging.warning(f"No metrics could be calculated for {model_path.name}.")
        return

    summary = {"model_name": model_path.name, "file_count": len(all_metrics)}

    # Signal metrics: avg_<metric>.
    for key, mean_value in _mean_by_key(all_metrics).items():
        summary[f"avg_{key}"] = mean_value

    # Loudness metrics: avg_<gen|gt>_<metric>, lower-cased with spaces as "_".
    for prefix, loudness_records in (("gen", gen_loudness_data), ("gt", gt_loudness_data)):
        for key, mean_value in _mean_by_key(loudness_records).items():
            summary[f"avg_{prefix}_{key.lower().replace(' ', '_')}"] = mean_value

    logging.info(f"Saving results for {model_path.name} to {results_file}")
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=4)

    # Plain-text copy of the same summary, one "key: value" per line.
    with open(model_path / "evaluation_summary.txt", "w", encoding='utf-8') as f:
        for key, value in summary.items():
            f.write(f"{key}: {value}\n")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses ``--input_dir`` (required), ``--force`` and ``--echo``, then runs
    process_model on every non-hidden sub-directory of the input directory in
    sorted order. Exits with status 1 when the input directory does not exist.
    """
    cli = argparse.ArgumentParser(description="Run evaluation on generated audio.")
    cli.add_argument(
        "--input_dir",
        required=True,
        type=str,
        help="Root directory containing model output folders.",
    )
    cli.add_argument(
        "--force",
        action="store_true",
        help="Force re-evaluation even if results files exist.",
    )
    cli.add_argument(
        "--echo",
        action="store_true",
        help="Echo per-file metrics to console during evaluation.",
    )
    options = cli.parse_args()

    root_path = Path(options.input_dir)
    if not root_path.is_dir():
        logging.error(f"Input directory not found: {root_path}")
        sys.exit(1)

    # Every visible sub-directory is treated as one model's output folder.
    model_paths = []
    for entry in root_path.iterdir():
        if entry.is_dir() and not entry.name.startswith('.'):
            model_paths.append(entry)

    model_names = [entry.name for entry in model_paths]
    logging.info(f"Found {len(model_paths)} model(s) to evaluate: {model_names}")

    for model_path in sorted(model_paths):
        process_model(model_path, options.force, options.echo)

    logging.info("Evaluation complete.")


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()