"""Evaluate generated audio against reference audio.

For every (target, output) pair found on disk, optionally computes SI-SNR, PESQ,
Multi-Mel-SNR and AudioBox Aesthetics scores, then FAD-CLAP over the whole set,
and writes a plain-text report. All audio is resampled to 48 kHz.
"""
import argparse
import glob
import os
import re
import statistics  # mean / standard deviation for the summary section
import sys
import warnings

import numpy as np
import soundfile as sf
import torch
import torchaudio
import torchaudio.transforms as T
from audiobox_aesthetics.infer import initialize_predictor
from pesq import pesq
from scipy.linalg import sqrtm
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
from tqdm import tqdm

warnings.filterwarnings("ignore")

try:
    from transformers import ClapModel, ClapProcessor
except ImportError:
    print("Error: The 'transformers' library is not installed.")
    print("Please install it to run FAD-CLAP calculations:")
    print("pip install torch transformers")
    sys.exit(1)


# Sample rate every file is resampled to before any metric is computed.
TARGET_SR = 48000
# The four AudioBox Aesthetics axes read from each predictor result dict.
AXES_NAME = ["CE", "CU", "PC", "PQ"]
# Number of files sent to the aesthetics predictor per forward call.
AESTHETICS_CHUNK_SIZE = 64
# Default local model locations (overridable on the command line).
DEFAULT_AESTHETICS_CKPT = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/audiobox/audiobox_aes_checkpoint.pt"
DEFAULT_CLAP_MODEL_PATH = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/clap-model"
# (n_fft, hop_length, n_mels) of the three mel resolutions averaged by multi_mel_snr.
MEL_SNR_CONFIGS = (
    (512, 256, 80),
    (1024, 512, 128),
    (2048, 1024, 192),
)


def multi_mel_snr(reference, prediction, sr=48000):
    """Compute Multi-Mel-SNR (dB) between a 1-D reference and prediction.

    The prediction is first rescaled by the least-squares gain onto the
    reference (scale invariance). For each mel configuration in
    MEL_SNR_CONFIGS, the SNR of the mel power spectrograms is computed; the
    result is the mean over configurations.

    Args:
        reference: 1-D waveform (numpy array or torch tensor).
        prediction: 1-D waveform of the same length as ``reference``.
        sr: sample rate in Hz; the mel filterbanks span 0 .. sr / 2.

    Returns:
        The mean SNR in dB as a Python float.
    """
    if not isinstance(reference, torch.Tensor):
        reference = torch.from_numpy(reference).float()
    if not isinstance(prediction, torch.Tensor):
        prediction = torch.from_numpy(prediction).float()

    # Scale-invariant normalisation: project the prediction onto the reference gain.
    alpha = torch.dot(reference, prediction) / (torch.dot(prediction, prediction) + 1e-8)
    prediction = alpha * prediction

    snrs = []
    for n_fft, hop, n_mels in MEL_SNR_CONFIGS:
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop,
            n_mels=n_mels,
            f_min=0,
            f_max=sr / 2,  # Nyquist; equals the former hard-coded 24000 at 48 kHz
            power=2.0,
        )
        m_ref = mel(reference)
        m_pred = mel(prediction)
        snr = 10 * torch.log10(m_ref.pow(2).sum() / ((m_ref - m_pred).pow(2).sum() + 1e-8))
        snrs.append(snr.item())
    return sum(snrs) / len(snrs)


def load_audio(file_path, target_sr=48000):
    """Load an audio file as a (channels, samples) float32 tensor resampled to ``target_sr``.

    Returns:
        The waveform tensor, or None (after printing a warning) if the file
        cannot be read or resampled.
    """
    try:
        wav, samplerate = sf.read(file_path)
        # soundfile returns (frames,) for mono and (frames, channels) otherwise.
        if wav.ndim > 1:
            wav = wav.T
        else:
            wav = wav[np.newaxis, :]
        wav_tensor = torch.from_numpy(wav).float()
        if samplerate != target_sr:
            print(f"Warning: Resampling audio from {samplerate} to {target_sr}")
            resampler = T.Resample(orig_freq=samplerate, new_freq=target_sr)
            wav_tensor = resampler(wav_tensor)
        return wav_tensor
    except Exception as e:
        print(f"Warning: failed to load audio {file_path}: {e}")
        return None


def get_clap_embeddings(file_paths, model, processor, device, batch_size=16):
    """Compute CLAP audio embeddings for ``file_paths``.

    Each channel of each file is embedded separately, so a stereo file
    contributes two rows. Files or batches that fail are skipped with a warning.

    Returns:
        A numpy array of shape (num_channels_total, embed_dim), or an empty
        array if nothing could be embedded.
    """
    model.to(device)
    all_embeddings = []
    for i in tqdm(range(0, len(file_paths), batch_size), desc=" Calculating embeddings", ncols=100, leave=False):
        batch_paths = file_paths[i:i + batch_size]
        audio_batch = []
        for path in batch_paths:
            wav_tensor = load_audio(path, target_sr=48000)
            if wav_tensor is None:
                continue
            audio_batch.extend(channel.numpy() for channel in wav_tensor)
        if not audio_batch:
            continue
        try:
            inputs = processor(audios=audio_batch, sampling_rate=48000, return_tensors="pt", padding=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                audio_features = model.get_audio_features(**inputs)
            all_embeddings.append(audio_features.cpu().numpy())
        except Exception as e:
            print(f"Warning: CLAP embedding failed for batch starting at {i}: {e}")
            continue
    if not all_embeddings:
        return np.array([])
    return np.concatenate(all_embeddings, axis=0)


def calculate_frechet_distance(embeddings1, embeddings2, eps=1e-6):
    """Return the Frechet distance between two embedding sets, or None if it cannot be computed.

    Args:
        embeddings1, embeddings2: arrays of shape (num_samples, dim); each needs
            at least two rows so a covariance can be estimated.
        eps: diagonal offset added to both covariances when the plain matrix
            square root is not finite (near-singular covariances).
    """
    if embeddings1.shape[0] < 2 or embeddings2.shape[0] < 2:
        return None
    mu1, mu2 = np.mean(embeddings1, axis=0), np.mean(embeddings2, axis=0)
    sigma1, sigma2 = np.cov(embeddings1, rowvar=False), np.cov(embeddings2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    try:
        covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            # Regularise near-singular covariances instead of returning inf/nan.
            offset = np.eye(sigma1.shape[0]) * eps
            covmean, _ = sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    except Exception:
        return None
    if np.iscomplexobj(covmean):
        # Numerical error can introduce a tiny imaginary component.
        covmean = covmean.real
    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))


def find_matching_pairs(target_dir, output_dir, target_index):
    """Find (target_file, output_file) pairs.

    Target files are expected to be named like ``0.flac, 1.flac, ..., 249.flac``
    and output files like ``{target_id}_DT{index}.flac``.

    Args:
        target_index: optional regex fragment for the DT index (e.g. ``'11|12'``);
            it is inserted unescaped, so alternation is allowed. When None,
            any numeric index matches.

    Returns:
        A list of (target_path, output_path) tuples, sorted by target then output path.
    """
    pairs = []
    target_files = sorted(glob.glob(os.path.join(target_dir, "*.*")))
    print(f"Found {len(target_files)} target files in {target_dir}")
    for target_file in target_files:
        target_id = os.path.splitext(os.path.basename(target_file))[0]
        output_pattern = os.path.join(output_dir, f"{target_id}_DT*.*")
        matching_outputs = glob.glob(output_pattern)
        if target_index is not None:
            regex = re.compile(rf"^{re.escape(target_id)}_DT({target_index})\.\w+$")
        else:
            regex = re.compile(rf"^{re.escape(target_id)}_DT\d+\.\w+$")
        matching_outputs = sorted(f for f in matching_outputs if regex.match(os.path.basename(f)))
        if matching_outputs:
            print(f"Target {target_id}: found {len(matching_outputs)} output files")
            pairs.extend((target_file, output_file) for output_file in matching_outputs)
        else:
            print(f"Target {target_id}: no matching output files found")
    return pairs


def calculate_pesq(target_wav, output_wav, target_sr=48000, pesq_sr=16000):
    """Compute wide-band PESQ between two time-aligned waveforms.

    Multi-channel inputs are reduced to their first channel, and both signals
    are resampled from ``target_sr`` to ``pesq_sr`` (16 kHz for wide-band mode).

    Returns:
        The PESQ score, or NaN (after printing a warning) if PESQ fails.
    """
    # PESQ is defined on mono signals; keep only the first channel of (C, L) input.
    if target_wav.ndim > 1 and target_wav.shape[0] > 1:
        target_wav = target_wav[0:1, :]
    if output_wav.ndim > 1 and output_wav.shape[0] > 1:
        output_wav = output_wav[0:1, :]

    if target_sr != pesq_sr:
        resampler = T.Resample(orig_freq=target_sr, new_freq=pesq_sr)
        target_wav = resampler(target_wav)
        output_wav = resampler(output_wav)

    target_np = target_wav.squeeze(0).numpy()
    output_np = output_wav.squeeze(0).numpy()
    try:
        # Wide-band mode, matching the 16 kHz resampling above.
        return pesq(pesq_sr, target_np, output_np, 'wb')
    except Exception as e:
        print(f"Warning: PESQ calculation failed for a pair. Error: {e}")
        return float('nan')


def _parse_args():
    """Parse the command line and fill in the default output_file name."""
    parser = argparse.ArgumentParser(
        description="Calculate SI-SNR and FAD-CLAP for audio pairs. All audio is resampled to 48000Hz.")
    parser.add_argument("--target_dir", '-t', required=True, type=str, help="Path to target audio directory")
    parser.add_argument("--output_dir", '-o', required=True, type=str, help="Path to output audio directory")
    parser.add_argument("--target_index", '-i', type=str, help="Index of target audio files, e.g. '11|12'")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for FAD-CLAP embedding calculation.")
    parser.add_argument("--output_file", type=str, help="Filename to save all evaluation results.")
    parser.add_argument("--aesthetics_ckpt", type=str, default=DEFAULT_AESTHETICS_CKPT,
                        help="Path to the AudioBox Aesthetics checkpoint.")
    parser.add_argument("--clap_model_path", type=str, default=DEFAULT_CLAP_MODEL_PATH,
                        help="Path to the local CLAP model directory.")
    # Metric switches.
    parser.add_argument("--calc_sisnr", action="store_true", help="Calculate Scale-Invariant SNR (SI-SNR).")
    parser.add_argument("--calc_pesq", action="store_true", help="Calculate Perceptual Evaluation of Speech Quality (PESQ).")
    parser.add_argument("--calc_aesthetics", action="store_true", help="Calculate AudioBox Aesthetics MOS.")
    # FAD-CLAP and Multi-Mel-SNR are on by default. A store_true flag with default=True
    # can never become False, so the --no_* flags are the way to disable them.
    parser.add_argument("--calc_fad_clap", default=True, action="store_true", help="Calculate Frechet Audio Distance (FAD-CLAP).")
    parser.add_argument("--no_calc_fad_clap", dest="calc_fad_clap", action="store_false", help="Disable FAD-CLAP.")
    parser.add_argument("--calc_mel_snr", default=True, action="store_true", help="Calculate Multi-Mel-SNR.")
    parser.add_argument("--no_calc_mel_snr", dest="calc_mel_snr", action="store_false", help="Disable Multi-Mel-SNR.")
    args = parser.parse_args()

    if not args.output_file:
        args.output_file = args.output_dir[:-1] if args.output_dir.endswith('/') else args.output_dir
        if args.target_index:
            args.output_file += f"_{args.target_index}"
        args.output_file += ".txt"
    return args


def _load_aesthetics_predictor(args):
    """Load the AudioBox Aesthetics predictor when requested; return None if disabled or on failure."""
    if not args.calc_aesthetics:
        return None
    try:
        print("\nLoading AudioBox Aesthetics predictor...")
        predictor = initialize_predictor(ckpt=args.aesthetics_ckpt)
        print("AudioBox Aesthetics predictor loaded successfully.")
        return predictor
    except Exception as e:
        print(f"Error loading AudioBox Aesthetics predictor: {e}. Aesthetics calculation will be skipped.")
        return None


def _load_aligned_pair(target_path, output_path):
    """Load both files, match the output's channel count to the target's, and trim to a common length.

    Raises:
        FileNotFoundError / ValueError describing why the pair must be skipped.
    """
    if not os.path.exists(target_path) or not os.path.exists(output_path):
        raise FileNotFoundError(f"Skipping, file not found: {target_path} -> {output_path}")
    target_wav = load_audio(target_path, TARGET_SR)
    output_wav = load_audio(output_path, TARGET_SR)
    if target_wav is None or output_wav is None:
        raise ValueError(f"Skipping, waveform not loaded: {target_path} -> {output_path}")

    if target_wav.shape[0] != output_wav.shape[0]:
        print(f"Warning: shape mismatch: {target_path} -> {output_path}")
        if target_wav.shape[0] not in [1, 2]:
            raise ValueError(f"Skipping, unsupported shape: {target_path} -> {output_path}")
        if output_wav.shape[0] not in [1, 2]:
            raise ValueError(f"Skipping, unsupported shape: {target_path} -> {output_path}")
        if target_wav.shape[0] > output_wav.shape[0]:
            # Stereo target, mono output: duplicate the output channel.
            output_wav = output_wav.repeat(2, 1)
        else:
            # Mono target, stereo output: downmix the output.
            output_wav = output_wav.mean(dim=0, keepdim=True)

    min_len = min(target_wav.shape[-1], output_wav.shape[-1])
    target_wav = target_wav[..., :min_len]
    output_wav = output_wav[..., :min_len]
    if target_wav.shape[-1] == 0:
        raise ValueError(f"Skipping, zero-length waveform: {target_path} -> {output_path}")
    return target_wav, output_wav


def _evaluate_pair(target_path, output_path, args, sisnr_calculator):
    """Compute the enabled pairwise metrics for one pair; disabled metrics are NaN.

    Returns:
        A dict with keys 'sisnr', 'pesq' and 'mel_snr'.
    """
    target_wav, output_wav = _load_aligned_pair(target_path, output_path)
    nan = float('nan')

    sisnr_val = sisnr_calculator(output_wav, target_wav).item() if args.calc_sisnr else nan
    pesq_val = calculate_pesq(target_wav, output_wav, TARGET_SR) if args.calc_pesq else nan

    mel_snr_val = nan
    if args.calc_mel_snr:
        # multi_mel_snr expects mono input (it applies its own scale normalisation),
        # so score each channel and average.
        mel_snrs = [multi_mel_snr(target_wav[ch], output_wav[ch], sr=TARGET_SR)
                    for ch in range(target_wav.shape[0])]
        if mel_snrs:
            mel_snr_val = sum(mel_snrs) / len(mel_snrs)

    output_str = f"{target_path}|{output_path}"
    if args.calc_sisnr:
        output_str += f"|SI-SNR:{sisnr_val:.4f}"
    if args.calc_pesq:
        output_str += f"|PESQ:{pesq_val:.4f}"
    if args.calc_mel_snr:
        output_str += f"|Mel-SNR:{mel_snr_val:.4f}"
    print(output_str)
    return {'sisnr': sisnr_val, 'pesq': pesq_val, 'mel_snr': mel_snr_val}


def _evaluate_all_pairs(args):
    """Find all pairs and evaluate them; failing pairs are reported and skipped.

    Returns:
        (target_paths, output_paths, pairwise_values), where every list in
        pairwise_values is index-aligned with the two path lists.
    """
    sisnr_calculator = ScaleInvariantSignalNoiseRatio()
    target_paths, output_paths = [], []
    pairwise_values = {'sisnr': [], 'pesq': [], 'mel_snr': []}

    print("--- Finding matching file pairs ---")
    pairs = find_matching_pairs(args.target_dir, args.output_dir, args.target_index)
    print(f"Found {len(pairs)} file pairs")
    for target_path, output_path in pairs:
        try:
            metrics = _evaluate_pair(target_path, output_path, args, sisnr_calculator)
        except Exception as e:
            print(f"Error processing {target_path} -> {output_path}: {e}")
            continue
        # Append only after every metric succeeded so all lists stay aligned.
        for key, value in metrics.items():
            pairwise_values[key].append(value)
        target_paths.append(target_path)
        output_paths.append(output_path)
    return target_paths, output_paths, pairwise_values


def _compute_aesthetics(predictor, output_paths):
    """Score ``output_paths`` with the aesthetics predictor in chunks.

    Returns:
        A dict axis -> list of scores aligned with ``output_paths``; entries are
        NaN where a result was missing or its chunk failed.
    """
    scores = {axis: [] for axis in AXES_NAME}
    print("\n--- Calculating AudioBox Aesthetics Scores (Batch) ---")
    for i in tqdm(range(0, len(output_paths), AESTHETICS_CHUNK_SIZE), desc=" Aesthetics chunks"):
        chunk_paths = output_paths[i:i + AESTHETICS_CHUNK_SIZE]
        aesthetics_input_list = [{"path": p} for p in chunk_paths]
        try:
            aesthetics_results = aesthetics_predictor_forward(predictor, aesthetics_input_list)
            num_results = len(aesthetics_results)
            for j in range(len(chunk_paths)):
                if j < num_results and all(axis in aesthetics_results[j] for axis in AXES_NAME):
                    for axis in AXES_NAME:
                        scores[axis].append(aesthetics_results[j][axis])
                else:
                    for axis in AXES_NAME:
                        scores[axis].append(float('nan'))
        except Exception as e:
            print(f"\nError in chunk {i//AESTHETICS_CHUNK_SIZE}: {e}. Skipping chunk.")
            # Drop any partial results of this chunk, then fill the whole chunk with NaN.
            for axis in AXES_NAME:
                del scores[axis][i:]
                scores[axis].extend([float('nan')] * len(chunk_paths))
            if "CUDA out of memory" in str(e):
                print("FATAL OOM: Please reduce AESTHETICS_CHUNK_SIZE and restart.")
    return scores


def aesthetics_predictor_forward(predictor, input_list):
    """Run one batched forward pass of the aesthetics predictor (thin wrapper kept for clarity)."""
    return predictor.forward(input_list)


def _format_cell(value, width):
    """Format a score left-aligned in ``width`` columns with 4 decimals, or 'N/A' for NaN."""
    if np.isnan(value):
        return f"{'N/A':<{width}}"
    return f"{value:<{width}.4f}"


def _finite_mean_std(values):
    """Return (mean, sample std, count) over the non-NaN values, or None if there are none."""
    valid = [v for v in values if not np.isnan(v)]
    if not valid:
        return None
    std = statistics.stdev(valid) if len(valid) > 1 else 0.0
    return statistics.mean(valid), std, len(valid)


def _write_pairwise_table(results_file, args, target_paths, output_paths, pairwise_values, aesthetics_values):
    """Write the per-pair table; only enabled metrics get a column."""
    results_file.write("\n--- Pairwise Metrics ---\n")
    header_metrics = f"{'Target Filename':<30}|{'Output Filename':<30}"
    if args.calc_sisnr:
        header_metrics += f"|{'SI-SNR (dB)':<15}"
    if args.calc_pesq:
        header_metrics += f"|{'PESQ':<8}"
    if args.calc_mel_snr:
        header_metrics += f"|{'Mel-SNR (dB)':<15}"
    if args.calc_aesthetics:
        for axis in AXES_NAME:
            header_metrics += f"|{axis:<10}"
    results_file.write(header_metrics + "\n")
    results_file.write("-" * len(header_metrics) + "\n")

    print("\n--- Writing results to file ---")
    for i in tqdm(range(len(target_paths)), desc=" Writing results", ncols=100):
        row = f"{os.path.basename(target_paths[i]):<30}|{os.path.basename(output_paths[i]):<30}"
        if args.calc_sisnr:
            row += f"|{_format_cell(pairwise_values['sisnr'][i], 15)}"
        if args.calc_pesq:
            row += f"|{_format_cell(pairwise_values['pesq'][i], 8)}"
        if args.calc_mel_snr:
            row += f"|{_format_cell(pairwise_values['mel_snr'][i], 15)}"
        if args.calc_aesthetics:
            for axis in AXES_NAME:
                row += f"|{_format_cell(aesthetics_values[axis][i], 10)}"
        results_file.write(row + "\n")


def _write_summary(results_file, args, pairwise_values, aesthetics_values):
    """Write mean / standard deviation of every enabled metric (NaN values are ignored)."""
    results_file.write("\n\n--- Overall Statistical Metrics ---\n")

    if args.calc_sisnr and pairwise_values['sisnr']:
        stats = _finite_mean_std(pairwise_values['sisnr'])
        if stats is not None:
            avg, std, _ = stats
            results_file.write(f"SI-SNR (dB) Average: {avg:.4f}\n")
            results_file.write(f"SI-SNR (dB) Std Dev: {std:.4f}\n")
        else:
            results_file.write("No valid SI-SNR values were calculated.\n")

    if args.calc_pesq and pairwise_values['pesq']:
        stats = _finite_mean_std(pairwise_values['pesq'])
        if stats is not None:
            avg, std, count = stats
            results_file.write(f"\nPESQ Average: {avg:.4f}\n")
            results_file.write(f"PESQ Std Dev: {std:.4f} (from {count} samples)\n")
        else:
            results_file.write("\nNo valid PESQ values were calculated.\n")

    if args.calc_mel_snr and pairwise_values['mel_snr']:
        stats = _finite_mean_std(pairwise_values['mel_snr'])
        if stats is not None:
            avg, std, count = stats
            results_file.write(f"\nMulti-Mel-SNR Average: {avg:.4f}\n")
            results_file.write(f"Multi-Mel-SNR Std Dev: {std:.4f} (from {count} samples)\n")
        else:
            results_file.write("\nNo valid Multi-Mel-SNR values were calculated.\n")

    if args.calc_aesthetics:
        results_file.write("\n--- Aesthetics MOS ---\n")
        for axis in AXES_NAME:
            stats = _finite_mean_std(aesthetics_values[axis])
            if stats is not None:
                avg, std, count = stats
                results_file.write(f" {axis} (Avg/Std): {avg:.4f} / {std:.4f} (from {count} samples)\n")
            else:
                results_file.write(f" {axis} (Avg/Std): N/A (No valid scores calculated)\n")


def _write_fad_clap(results_file, args, target_paths, output_paths):
    """Compute FAD-CLAP between all target and all output files and write it to the report."""
    print("\n--- Calculating FAD-CLAP (48kHz) ---")
    if not target_paths:
        results_file.write("\nFAD-CLAP: Skipped (No valid file pairs found).\n")
        return
    try:
        results_file.write(f"\nTotal pairs for FAD-CLAP: {len(target_paths)}\n")
        print("Loading CLAP model...")
        clap_model = ClapModel.from_pretrained(args.clap_model_path, local_files_only=True)
        clap_processor = ClapProcessor.from_pretrained(args.clap_model_path, local_files_only=True)
        clap_model.eval()
        print("CLAP model loaded successfully.")
    except Exception as e:
        error_msg = f"Fatal Error: Could not load CLAP model. Error: {e}"
        print(error_msg)
        results_file.write(f"\nFAD-CLAP: {error_msg}\n")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print("\nCalculating embeddings for all target files...")
    target_embeddings = get_clap_embeddings(target_paths, clap_model, clap_processor, device, args.batch_size)
    print("Calculating embeddings for all output files...")
    output_embeddings = get_clap_embeddings(output_paths, clap_model, clap_processor, device, args.batch_size)

    if target_embeddings.size == 0 or output_embeddings.size == 0:
        msg = "\nCould not calculate FAD-CLAP due to issues with embedding generation."
        print(msg)
        results_file.write(f"\nFAD-CLAP: {msg}\n")
        return

    print("Calculating Frechet Audio Distance (FAD)...")
    fad_score = calculate_frechet_distance(target_embeddings, output_embeddings)
    if fad_score is None:
        msg = "\nCould not calculate FAD-CLAP score."
        print(msg)
        results_file.write(f"\nFAD-CLAP: {msg}\n")
    else:
        final_fad_output = f"\nOverall FAD-CLAP Score: {fad_score:.4f}"
        print(final_fad_output)
        results_file.write(final_fad_output + "\n")


def main():
    """Entry point: evaluate all pairs and write the report to ``--output_file``.

    Raises:
        FileExistsError: if the report file already exists (it is never overwritten).
        RuntimeError: if per-pair metric lists end up misaligned (internal error).
    """
    args = _parse_args()
    results_filename = args.output_file
    # Checked before any model is loaded so an existing report fails fast.
    if os.path.exists(results_filename):
        raise FileExistsError(f"Output file already exists: {results_filename}")

    aesthetics_predictor = _load_aesthetics_predictor(args)

    with open(results_filename, 'w', encoding='utf-8') as results_file:
        results_file.write("--- Audio Evaluation Results ---\n")
        print(f"所有结果将被写入文件: {results_filename}")

        # PHASE 1: per-pair metrics.
        print("\n--- Calculating SI-SNR (48kHz) for each pair ---")
        results_file.write("\n--- Pairwise SI-SNR (dB) ---\n")
        target_paths, output_paths, pairwise_values = _evaluate_all_pairs(args)
        num_pairs = len(target_paths)

        # PHASE 2: batched aesthetics scores.
        aesthetics_values = {axis: [] for axis in AXES_NAME}
        if args.calc_aesthetics and aesthetics_predictor is not None and output_paths:
            aesthetics_values = _compute_aesthetics(aesthetics_predictor, output_paths)
        # Pad with NaN whenever scores are missing (disabled, or predictor failed to load)
        # so every list lines up with the pairs.
        for axis in AXES_NAME:
            missing = num_pairs - len(aesthetics_values[axis])
            if missing > 0:
                aesthetics_values[axis].extend([float('nan')] * missing)

        # Internal consistency checks: every metric list must have one entry per pair.
        for metric_name, scores in pairwise_values.items():
            if len(scores) != num_pairs:
                raise RuntimeError(f"内部错误:指标 '{metric_name}' 的结果数量 ({len(scores)}) 与文件对数量 ({num_pairs}) 不匹配。")
        if args.calc_aesthetics:
            for axis in AXES_NAME:
                scores = aesthetics_values[axis]
                if len(scores) != num_pairs:
                    raise RuntimeError(f"内部错误:Aesthetics 指标 '{axis}' 的结果数量 ({len(scores)}) 与文件对数量 ({num_pairs}) 不匹配。")

        # PHASE 3: per-pair table.
        _write_pairwise_table(results_file, args, target_paths, output_paths, pairwise_values, aesthetics_values)
        # PHASE 4: overall statistics.
        _write_summary(results_file, args, pairwise_values, aesthetics_values)
        # PHASE 5: FAD-CLAP over the whole set.
        if args.calc_fad_clap:
            _write_fad_clap(results_file, args, target_paths, output_paths)

        results_file.write("\n--- End of Report ---")
    print(f"\nDone!!!! Save the result into {results_filename}。")


if __name__ == "__main__":
    main()