Spaces:
Running
Running
| import glob | |
| import os | |
| import re | |
| from pesq import pesq | |
| import soundfile as sf | |
| import torch | |
| from torchmetrics.audio import ScaleInvariantSignalNoiseRatio | |
| import argparse | |
| import numpy as np | |
| import warnings | |
| from scipy.linalg import sqrtm | |
| from tqdm import tqdm | |
| import torchaudio | |
| import torchaudio.transforms as T | |
| import statistics # <-- 新增导入,用于计算平均值和标准差 | |
| from audiobox_aesthetics.infer import initialize_predictor | |
| warnings.filterwarnings("ignore") | |
# The CLAP model classes are required for the FAD-CLAP metric; without them the
# script cannot run, so abort at import time with an installation hint.
try:
    from transformers import ClapModel, ClapProcessor
except ImportError:
    print("Error: The 'transformers' library is not installed.")
    print("Please install it to run FAD-CLAP calculations:")
    print("pip install torch transformers")
    # `exit()` is a convenience injected by the `site` module and is not
    # guaranteed to exist (e.g. `python -S`); raising SystemExit always works.
    raise SystemExit(1)
def multi_mel_snr(reference, prediction, sr=48000):
    """Compute the Multi-Mel-SNR (dB) between a reference and a prediction.

    The prediction is first rescaled by the least-squares gain that best
    matches the reference, then both signals are converted to power
    mel-spectrograms at three resolutions. For each resolution the SNR is
    ``10 * log10(sum(M_ref^2) / sum((M_ref - M_pred)^2))`` and the three
    values are averaged.

    Args:
        reference: 1-D waveform (NumPy array, sequence or tensor); ``torch.dot``
            below requires both inputs to be 1-D and of equal length.
        prediction: 1-D waveform with the same length as ``reference``.
        sr: Sample rate of both signals in Hz.

    Returns:
        The mean SNR over the three mel configurations, as a Python float.
    """
    reference = torch.as_tensor(reference).float()
    prediction = torch.as_tensor(prediction).float()

    # Scale-invariant normalization: project the prediction onto the reference.
    alpha = torch.dot(reference, prediction) / (torch.dot(prediction, prediction) + 1e-8)
    prediction = alpha * prediction

    # Three mel configurations: (n_fft, hop_length, n_mels)
    configs = (
        (512, 256, 80),
        (1024, 512, 128),
        (2048, 1024, 192),
    )

    # The upper mel edge follows the sample rate instead of being pinned to
    # 24 kHz, which was only correct for the default sr=48000.
    nyquist = sr / 2

    snrs = []
    for n_fft, hop, n_mels in configs:
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=n_fft, hop_length=hop,
            n_mels=n_mels, f_min=0, f_max=nyquist, power=2.0,
        ).to(reference.device)
        m_ref = mel(reference)
        m_pred = mel(prediction)
        noise_energy = (m_ref - m_pred).pow(2).sum() + 1e-8
        snr = 10 * torch.log10(m_ref.pow(2).sum() / noise_energy)
        snrs.append(snr.item())
    return sum(snrs) / len(snrs)
def load_audio(file_path, target_sr=48000):
    """Load an audio file and resample it to ``target_sr``.

    Args:
        file_path: Path of the audio file to read with ``soundfile``.
        target_sr: Sample rate (Hz) the returned waveform should have.

    Returns:
        A float32 tensor with the channel axis first (mono input gets a
        leading axis of size 1), or ``None`` when the file could not be read
        or resampled. Callers rely on the ``None`` return to skip bad files.
    """
    try:
        wav, samplerate = sf.read(file_path)
        # Multi-dimensional data is transposed so channels come first;
        # 1-D (mono) data gets an explicit channel axis.
        wav = wav.T if wav.ndim > 1 else wav[np.newaxis, :]
        wav_tensor = torch.from_numpy(wav).float()

        if samplerate != target_sr:
            print(f"Warning: Resampling audio from {samplerate} to {target_sr}")
            resampler = T.Resample(orig_freq=samplerate, new_freq=target_sr)
            wav_tensor = resampler(wav_tensor)
        return wav_tensor
    except Exception as e:
        # Best-effort loader: keep returning None, but say why instead of
        # hiding the failure completely.
        print(f"Warning: failed to load audio '{file_path}': {e}")
        return None
def get_clap_embeddings(file_paths, model, processor, device, batch_size=16):
    """Compute CLAP audio embeddings for a list of audio files.

    Files are loaded at 48 kHz with ``load_audio`` and every channel of every
    file is embedded as a separate item. Files that fail to load and batches
    that fail during inference are skipped with a warning.

    Args:
        file_paths: Sequence of audio file paths.
        model: Object exposing ``to(device)`` and ``get_audio_features(**inputs)``.
        processor: Callable turning a list of waveforms into model inputs.
        device: Torch device the model and inputs are moved to.
        batch_size: Number of files (not channels) per processor call.

    Returns:
        The per-batch feature arrays concatenated along axis 0, or an empty
        array when nothing could be embedded.
    """
    model.to(device)
    all_embeddings = []

    batch_starts = range(0, len(file_paths), batch_size)
    for start in tqdm(batch_starts, desc=" Calculating embeddings", ncols=100, leave=False):
        audio_batch = []
        for path in file_paths[start:start + batch_size]:
            try:
                wav_tensor = load_audio(path, target_sr=48000)
                if wav_tensor is None:
                    continue
                audio_batch.extend(channel.numpy() for channel in wav_tensor)
            except Exception as e:
                print(f"Warning: skipping '{path}' for CLAP embedding: {e}")
                continue

        if not audio_batch:
            continue

        try:
            inputs = processor(audios=audio_batch, sampling_rate=48000, return_tensors="pt", padding=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                audio_features = model.get_audio_features(**inputs)
            all_embeddings.append(audio_features.cpu().numpy())
        except Exception as e:
            # A failed batch silently shrinks the sample set used for FAD, so
            # make the loss visible.
            print(f"Warning: CLAP embedding failed for batch starting at index {start}: {e}")
            continue

    if not all_embeddings:
        return np.array([])
    return np.concatenate(all_embeddings, axis=0)
def calculate_frechet_distance(embeddings1, embeddings2):
    """Compute the Frechet distance between two sets of embeddings.

    Each set is summarised by its mean and covariance; the distance is
    ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))``.

    Args:
        embeddings1: Array with one embedding per row.
        embeddings2: Array with one embedding per row (same feature size).

    Returns:
        The distance as a float, or ``None`` when either set has fewer than
        two rows, the matrix square root fails, or the result is not finite.
    """
    if embeddings1.shape[0] < 2 or embeddings2.shape[0] < 2:
        return None

    mu1 = np.mean(embeddings1, axis=0)
    mu2 = np.mean(embeddings2, axis=0)
    # np.cov returns a 0-d array for a single feature; keep matrices 2-D so
    # the matrix product and trace below stay valid.
    sigma1 = np.atleast_2d(np.cov(embeddings1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(embeddings2, rowvar=False))
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    try:
        # Call sqrtm without `disp`: that keyword is deprecated in recent
        # SciPy releases, and the plain call returns just the matrix in both
        # old and new versions.
        covmean = sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            # Near-singular product: retry with a small diagonal offset.
            offset = np.eye(sigma1.shape[0]) * 1e-6
            covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    except Exception as e:
        print(f"Warning: matrix square root failed during Frechet distance: {e}")
        return None

    # Numerical error can leave a small imaginary component; discard it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fad_score = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    if not np.isfinite(fad_score):
        return None
    return float(fad_score)
def _natural_sort_key(path):
    """Sort key that orders digit runs in a file name numerically (2 before 10)."""
    parts = re.split(r"(\d+)", os.path.basename(path))
    # re.split with one capture group alternates text / digits, so odd
    # positions are always the digit runs.
    return [int(part) if idx % 2 else part for idx, part in enumerate(parts)]


def find_matching_pairs(target_dir, output_dir, target_index):
    """Pair every target file with its generated output files.

    A target named ``<target_id>.<ext>`` is matched with outputs named
    ``<target_id>_DT<index>.<ext>`` in ``output_dir``
    (e.g. ``0.flac`` -> ``0_DT3.flac``).

    Args:
        target_dir: Directory containing the target (reference) files.
        output_dir: Directory containing the generated files.
        target_index: ``None`` to accept any numeric index, otherwise a regex
            fragment for the accepted index values, e.g. ``"11|12"``.

    Returns:
        A list of ``(target_file, output_file)`` tuples, targets and outputs
        both in natural (numeric-aware) file-name order.
    """
    pairs = []
    # Escape the literal parts so directory or file names containing glob
    # metacharacters such as "[" are matched verbatim.
    target_files = [
        path for path in glob.glob(os.path.join(glob.escape(target_dir), "*.*"))
        if os.path.isfile(path)
    ]
    target_files.sort(key=_natural_sort_key)
    print(f"Found {len(target_files)} target files in {target_dir}")

    for target_file in target_files:
        target_id = os.path.splitext(os.path.basename(target_file))[0]
        output_pattern = os.path.join(glob.escape(output_dir), f"{glob.escape(target_id)}_DT*.*")
        candidates = glob.glob(output_pattern)

        # The glob is only a coarse pre-filter; the regex enforces the exact
        # "<id>_DT<index>.<ext>" shape.
        if target_index is not None:
            regex = re.compile(rf"^{re.escape(target_id)}_DT({target_index})\.\w+$")
        else:
            regex = re.compile(rf"^{re.escape(target_id)}_DT\d+\.\w+$")
        matching_outputs = sorted(
            (f for f in candidates if regex.match(os.path.basename(f))),
            key=_natural_sort_key,
        )

        if matching_outputs:
            print(f"Target {target_id}: found {len(matching_outputs)} output files")
            pairs.extend((target_file, output_file) for output_file in matching_outputs)
        else:
            print(f"Target {target_id}: no matching output files found")
    return pairs
# --- PESQ computation ---
def calculate_pesq(target_wav, output_wav, target_sr=48000, pesq_sr=16000):
    """Compute the PESQ score between a target and an output waveform.

    Only the first channel of each input is scored. Both signals are
    resampled from ``target_sr`` to ``pesq_sr`` before calling ``pesq``.

    Args:
        target_wav: Reference waveform tensor, channel axis first.
        output_wav: Degraded waveform tensor, channel axis first; expected to
            be time-aligned with ``target_wav``.
        target_sr: Sample rate (Hz) of the two input tensors.
        pesq_sr: Sample rate (Hz) passed to ``pesq``.

    Returns:
        The PESQ score, or ``float('nan')`` when the computation fails.
    """
    def _first_channel(wav):
        # Keep indexing the leading axis until a single 1-D channel remains,
        # then restore a channel axis of size 1 for the resampler.
        while wav.ndim > 1:
            wav = wav[0]
        return wav.unsqueeze(0)

    target_mono = _first_channel(target_wav)
    output_mono = _first_channel(output_wav)

    # Resample the mono signals themselves, so both sample-rate paths hand
    # identically shaped 1-D arrays to pesq.
    if target_sr != pesq_sr:
        resampler = T.Resample(orig_freq=target_sr, new_freq=pesq_sr)
        target_mono = resampler(target_mono)
        output_mono = resampler(output_mono)

    target_np = target_mono.squeeze(0).numpy()
    output_np = output_mono.squeeze(0).numpy()

    # Wide-band mode is only defined for 16 kHz input; use narrow-band for
    # any other rate.
    mode = 'wb' if pesq_sr == 16000 else 'nb'
    try:
        return pesq(pesq_sr, target_np, output_np, mode)
    except Exception as e:
        print(f"Warning: PESQ calculation failed for a pair. Error: {e}")
        return float('nan')
# Keys read from every AudioBox Aesthetics result, in report column order.
_AESTHETICS_AXES = ["CE", "CU", "PC", "PQ"]
# Number of files sent to the aesthetics predictor per forward call.
_AESTHETICS_CHUNK_SIZE = 64
# All audio is resampled to this rate before pairwise metrics are computed.
_TARGET_SR = 48000
# Defaults preserved from the original hard-coded locations.
_DEFAULT_AESTHETICS_CKPT = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/audiobox/audiobox_aes_checkpoint.pt"
_DEFAULT_CLAP_MODEL_PATH = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/clap-model"


def _parse_args():
    """Parse the command line and fill in the default report file name."""
    parser = argparse.ArgumentParser(description="Calculate SI-SNR and FAD-CLAP for audio pairs. All audio is resampled to 48000Hz.")
    parser.add_argument("--target_dir", '-t', required=True, type=str, help="Path to target audio directory")
    parser.add_argument("--output_dir", '-o', required=True, type=str, help="Path to output audio directory")
    parser.add_argument("--target_index", '-i', type=str, help="Index of target audio files, e.g. '11|12'")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for FAD-CLAP embedding calculation.")
    parser.add_argument("--output_file", type=str, help="Filename to save all evaluation results.")
    parser.add_argument("--aesthetics_ckpt", type=str, default=_DEFAULT_AESTHETICS_CKPT, help="Path to the AudioBox Aesthetics checkpoint.")
    parser.add_argument("--clap_model_path", type=str, default=_DEFAULT_CLAP_MODEL_PATH, help="Path to the local CLAP model directory.")

    # Metric switches
    parser.add_argument("--calc_sisnr", action="store_true", help="Calculate Scale-Invariant SNR (SI-SNR).")
    parser.add_argument("--calc_pesq", action="store_true", help="Calculate Perceptual Evaluation of Speech Quality (PESQ).")
    parser.add_argument("--calc_aesthetics", action="store_true", help="Calculate AudioBox Aesthetics MOS.")
    # These two default to on; the --no_* companions are the only way to turn
    # them off while keeping the original flags accepted.
    parser.add_argument("--calc_fad_clap", default=True, action="store_true", help="Calculate Frechet Audio Distance (FAD-CLAP).")
    parser.add_argument("--no_fad_clap", dest="calc_fad_clap", action="store_false", help="Disable the FAD-CLAP calculation.")
    parser.add_argument("--calc_mel_snr", default=True, action="store_true", help="Calculate Multi-Mel-SNR.")
    parser.add_argument("--no_mel_snr", dest="calc_mel_snr", action="store_false", help="Disable the Multi-Mel-SNR calculation.")
    args = parser.parse_args()

    if not args.output_file:
        base = args.output_dir[:-1] if args.output_dir.endswith('/') else args.output_dir
        if args.target_index:
            base += f"_{args.target_index}"
        args.output_file = base + ".txt"
    return args


def _load_aesthetics_predictor(args):
    """Return the AudioBox Aesthetics predictor, or None when disabled or unavailable."""
    if not args.calc_aesthetics:
        return None
    try:
        print("\nLoading AudioBox Aesthetics predictor...")
        predictor = initialize_predictor(ckpt=args.aesthetics_ckpt)
        print("AudioBox Aesthetics predictor loaded successfully.")
        return predictor
    except Exception as e:
        print(f"Error loading AudioBox Aesthetics predictor: {e}. Aesthetics calculation will be skipped.")
        return None


def _compute_pair_metrics(target_path, output_path, args, sisnr_calculator):
    """Compute the enabled pairwise metrics for one (target, output) pair.

    Returns a dict with keys 'sisnr', 'pesq' and 'mel_snr' (NaN for disabled
    metrics). Raises when the pair cannot be evaluated; nothing is recorded
    by this function, so a failure never leaves partial results behind.
    """
    if not os.path.exists(target_path) or not os.path.exists(output_path):
        raise FileNotFoundError(f"Skipping, file not found: {target_path} -> {output_path}")

    target_wav = load_audio(target_path, _TARGET_SR)
    output_wav = load_audio(output_path, _TARGET_SR)
    if target_wav is None or output_wav is None:
        raise ValueError(f"Skipping, waveform not loaded: {target_path} -> {output_path}")

    if target_wav.shape[0] != output_wav.shape[0]:
        print(f"Warning: shape mismatch: {target_path} -> {output_path}")
        if target_wav.shape[0] not in [1, 2]:
            raise ValueError(f"Skipping, unsupported shape: {target_path} -> {output_path}")
        if output_wav.shape[0] not in [1, 2]:
            raise ValueError(f"Skipping, unsupported shape: {target_path} -> {output_path}")
        if target_wav.shape[0] > output_wav.shape[0]:  # 2 vs 1: duplicate the mono output
            output_wav = output_wav.repeat(2, 1)
        else:  # 1 vs 2: downmix the stereo output
            output_wav = output_wav.mean(dim=0, keepdim=True)

    # Trim both signals to the shorter one so sample-wise metrics line up.
    min_len = min(target_wav.shape[-1], output_wav.shape[-1])
    target_wav = target_wav[..., :min_len]
    output_wav = output_wav[..., :min_len]
    if target_wav.shape[-1] == 0:
        raise ValueError(f"Skipping, zero-length waveform: {target_path} -> {output_path}")

    sisnr_val = float('nan')
    if args.calc_sisnr:
        sisnr_val = sisnr_calculator(output_wav, target_wav).item()

    pesq_val = float('nan')
    if args.calc_pesq:
        pesq_val = calculate_pesq(target_wav, output_wav, _TARGET_SR)

    mel_snr_val = float('nan')
    if args.calc_mel_snr:
        # multi_mel_snr works on 1-D signals, so score each channel and average.
        mel_snrs = [
            multi_mel_snr(target_wav[ch], output_wav[ch], sr=_TARGET_SR)
            for ch in range(target_wav.shape[0])
        ]
        mel_snr_val = sum(mel_snrs) / len(mel_snrs) if mel_snrs else float('nan')

    output_str = f"{target_path}|{output_path}"
    if args.calc_sisnr:
        output_str += f"|SI-SNR:{sisnr_val:.4f}"
    if args.calc_pesq:
        output_str += f"|PESQ:{pesq_val:.4f}"
    if args.calc_mel_snr:
        output_str += f"|Mel-SNR:{mel_snr_val:.4f}"
    print(output_str)

    return {'sisnr': sisnr_val, 'pesq': pesq_val, 'mel_snr': mel_snr_val}


def _score_aesthetics(predictor, output_paths):
    """Return {axis: scores} aligned with output_paths; NaN where no score is available."""
    nan = float('nan')
    if predictor is None or not output_paths:
        return {axis: [nan] * len(output_paths) for axis in _AESTHETICS_AXES}

    print("\n--- Calculating AudioBox Aesthetics Scores (Batch) ---")
    scores = {axis: [] for axis in _AESTHETICS_AXES}
    for start in tqdm(range(0, len(output_paths), _AESTHETICS_CHUNK_SIZE), desc=" Aesthetics chunks"):
        chunk_paths = output_paths[start:start + _AESTHETICS_CHUNK_SIZE]
        # Collect into a chunk-local buffer so a failure half-way through
        # cannot leave the per-axis lists with different lengths.
        chunk_scores = {axis: [] for axis in _AESTHETICS_AXES}
        try:
            results = predictor.forward([{"path": p} for p in chunk_paths])
            for j in range(len(chunk_paths)):
                complete = j < len(results) and all(axis in results[j] for axis in _AESTHETICS_AXES)
                for axis in _AESTHETICS_AXES:
                    chunk_scores[axis].append(results[j][axis] if complete else nan)
        except Exception as e:
            print(f"\nError in chunk {start//_AESTHETICS_CHUNK_SIZE}: {e}. Skipping chunk.")
            chunk_scores = {axis: [nan] * len(chunk_paths) for axis in _AESTHETICS_AXES}
            if "CUDA out of memory" in str(e):
                print("FATAL OOM: Please reduce AESTHETICS_CHUNK_SIZE and restart.")
        for axis in _AESTHETICS_AXES:
            scores[axis].extend(chunk_scores[axis])
    return scores


def _write_pairwise_table(results_file, args, target_paths, output_paths, pairwise, aesthetics):
    """Write one table row per evaluated pair; columns follow the enabled metrics."""
    results_file.write("\n--- Pairwise Metrics ---\n")

    header_metrics = f"{'Target Filename':<30}|{'Output Filename':<30}"
    if args.calc_sisnr:
        header_metrics += f"|{'SI-SNR (dB)':<15}"
    if args.calc_pesq:
        header_metrics += f"|{'PESQ':<8}"
    if args.calc_mel_snr:
        header_metrics += f"|{'Mel-SNR (dB)':<15}"
    if args.calc_aesthetics:
        for axis in _AESTHETICS_AXES:
            header_metrics += f"|{axis:<10}"
    results_file.write(header_metrics + "\n")
    results_file.write("-" * len(header_metrics) + "\n")

    print("\n--- Writing results to file ---")
    for i in tqdm(range(len(target_paths)), desc=" Writing results", ncols=100):
        target_filename = os.path.basename(target_paths[i])
        output_filename = os.path.basename(output_paths[i])
        result_line = f"{target_filename:<30}|{output_filename:<30}"

        if args.calc_sisnr:
            result_line += f"|{pairwise['sisnr'][i]:<15.4f}"
        if args.calc_pesq:
            pesq_item = pairwise['pesq'][i]
            result_line += "|" + (f"{pesq_item:<8.4f}" if not np.isnan(pesq_item) else f"{'N/A':<8}")
        if args.calc_mel_snr:
            mel_snr_item = pairwise['mel_snr'][i]
            result_line += "|" + (f"{mel_snr_item:<15.4f}" if not np.isnan(mel_snr_item) else f"{'N/A':<15}")
        # Aesthetics cells are only written when the header has the columns.
        if args.calc_aesthetics:
            for axis in _AESTHETICS_AXES:
                score = aesthetics[axis][i]
                aesthetics_str = f"{score:.4f}" if not np.isnan(score) else "N/A"
                result_line += f"|{aesthetics_str:<10}"

        results_file.write(result_line + "\n")


def _mean_std(values):
    """Return (mean, sample std, count) over the non-NaN values, or None if there are none."""
    valid = [v for v in values if not np.isnan(v)]
    if not valid:
        return None
    std = statistics.stdev(valid) if len(valid) > 1 else 0.0
    return statistics.mean(valid), std, len(valid)


def _write_summary(results_file, args, pairwise, aesthetics):
    """Write mean / standard deviation of every metric to the report."""
    results_file.write("\n\n--- Overall Statistical Metrics ---\n")

    if args.calc_sisnr and pairwise['sisnr']:
        stats = _mean_std(pairwise['sisnr'])
        if stats:
            results_file.write(f"SI-SNR (dB) Average: {stats[0]:.4f}\n")
            results_file.write(f"SI-SNR (dB) Std Dev: {stats[1]:.4f}\n")
        else:
            results_file.write("No valid SI-SNR values were calculated.\n")

    if args.calc_pesq and pairwise['pesq']:
        stats = _mean_std(pairwise['pesq'])
        if stats:
            results_file.write(f"\nPESQ Average: {stats[0]:.4f}\n")
            results_file.write(f"PESQ Std Dev: {stats[1]:.4f} (from {stats[2]} samples)\n")
        else:
            results_file.write("\nNo valid PESQ values were calculated.\n")

    if args.calc_mel_snr and pairwise['mel_snr']:
        stats = _mean_std(pairwise['mel_snr'])
        if stats:
            results_file.write(f"\nMulti-Mel-SNR Average: {stats[0]:.4f}\n")
            results_file.write(f"Multi-Mel-SNR Std Dev: {stats[1]:.4f} (from {stats[2]} samples)\n")
        else:
            results_file.write("\nNo valid Multi-Mel-SNR values were calculated.\n")

    results_file.write("\n--- Aesthetics MOS ---\n")
    for axis in _AESTHETICS_AXES:
        stats = _mean_std(aesthetics[axis])
        if stats:
            results_file.write(f" {axis} (Avg/Std): {stats[0]:.4f} / {stats[1]:.4f} (from {stats[2]} samples)\n")
        else:
            results_file.write(f" {axis} (Avg/Std): N/A (No valid scores calculated)\n")


def _write_fad_clap(results_file, args, target_paths, output_paths):
    """Load CLAP, embed both file sets and write the FAD-CLAP score to the report."""
    print("\n--- Calculating FAD-CLAP (48kHz) ---")
    if not target_paths:
        results_file.write("\nFAD-CLAP: Skipped (No valid file pairs found).\n")
        return

    try:
        results_file.write(f"\nTotal pairs for FAD-CLAP: {len(target_paths)}\n")
        print("Loading CLAP model...")
        clap_model = ClapModel.from_pretrained(args.clap_model_path, local_files_only=True)
        clap_processor = ClapProcessor.from_pretrained(args.clap_model_path, local_files_only=True)
        clap_model.eval()
        print("CLAP model loaded successfully.")
    except Exception as e:
        error_msg = f"Fatal Error: Could not load CLAP model. Error: {e}"
        print(error_msg)
        results_file.write(f"\nFAD-CLAP: {error_msg}\n")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("\nCalculating embeddings for all target files...")
    target_embeddings = get_clap_embeddings(target_paths, clap_model, clap_processor, device, args.batch_size)
    print("Calculating embeddings for all output files...")
    output_embeddings = get_clap_embeddings(output_paths, clap_model, clap_processor, device, args.batch_size)

    if target_embeddings.size > 0 and output_embeddings.size > 0:
        print("Calculating Frechet Audio Distance (FAD)...")
        fad_score = calculate_frechet_distance(target_embeddings, output_embeddings)
        if fad_score is not None:
            final_fad_output = f"\nOverall FAD-CLAP Score: {fad_score:.4f}"
            print(final_fad_output)
            results_file.write(final_fad_output + "\n")
        else:
            msg = "\nCould not calculate FAD-CLAP score."
            print(msg)
            results_file.write(f"\nFAD-CLAP: {msg}\n")
    else:
        msg = "\nCould not calculate FAD-CLAP due to issues with embedding generation."
        print(msg)
        results_file.write(f"\nFAD-CLAP: {msg}\n")


def main():
    """Evaluate generated audio against targets and write a text report.

    Steps: parse arguments, pair target/output files, compute the enabled
    pairwise metrics (SI-SNR, PESQ, Multi-Mel-SNR), score outputs with
    AudioBox Aesthetics when enabled, write the per-pair table and summary
    statistics, then compute FAD-CLAP over all evaluated files.

    Raises:
        Exception: If the report file already exists.
    """
    args = _parse_args()
    results_filename = args.output_file
    # Refuse to overwrite an earlier report; check before any model is loaded.
    if os.path.exists(results_filename):
        raise Exception(f"Output file already exists: {results_filename}")

    aesthetics_predictor = _load_aesthetics_predictor(args)
    sisnr_calculator = ScaleInvariantSignalNoiseRatio()

    # The context manager closes the report even if a later phase raises.
    with open(results_filename, 'w', encoding='utf-8') as results_file:
        results_file.write("--- Audio Evaluation Results ---\n")
        print(f"所有结果将被写入文件: {results_filename}")

        # PHASE 1: pairwise metrics
        print("\n--- Calculating SI-SNR (48kHz) for each pair ---")
        results_file.write("\n--- Pairwise SI-SNR (dB) ---\n")

        print("--- Finding matching file pairs ---")
        pairs = find_matching_pairs(args.target_dir, args.output_dir, args.target_index)
        print(f"Found {len(pairs)} file pairs")

        all_target_paths = []
        all_output_paths = []
        all_pairwise_values = {'sisnr': [], 'pesq': [], 'mel_snr': []}
        for target_path, output_path in pairs:
            try:
                metrics = _compute_pair_metrics(target_path, output_path, args, sisnr_calculator)
            except Exception as e:
                print(f"Error processing {target_path} -> {output_path}: {e}")
                continue
            # Record paths and metrics together so every list stays the same length.
            for metric_name, value in metrics.items():
                all_pairwise_values[metric_name].append(value)
            all_target_paths.append(target_path)
            all_output_paths.append(output_path)

        # PHASE 2: AudioBox Aesthetics (all-NaN when disabled or unavailable)
        all_aesthetics_values = _score_aesthetics(aesthetics_predictor, all_output_paths)

        # PHASE 3: per-pair table
        _write_pairwise_table(
            results_file, args, all_target_paths, all_output_paths,
            all_pairwise_values, all_aesthetics_values,
        )

        # PHASE 4: summary statistics
        _write_summary(results_file, args, all_pairwise_values, all_aesthetics_values)

        # FAD-CLAP
        if args.calc_fad_clap:
            _write_fad_clap(results_file, args, all_target_paths, all_output_paths)

        results_file.write("\n--- End of Report ---")

    print(f"\nDone!!!! Save the result into {results_filename}。")


if __name__ == "__main__":
    main()