Spaces:
Running on Zero
Running on Zero
File size: 25,733 Bytes
c362e04 cad654a 8fff3a7 6d640aa cad654a 8fff3a7 cad654a 8fff3a7 cad654a c362e04 cad654a 8fff3a7 cad654a 8fff3a7 c362e04 85dad86 c362e04 85dad86 c362e04 8fff3a7 3a1bba6 c362e04 cad654a 6d640aa cad654a 8fff3a7 cad654a 8fff3a7 cad654a c362e04 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 6d640aa 8fff3a7 cad654a 8fff3a7 c362e04 8de1b87 c362e04 8fff3a7 c362e04 8fff3a7 c362e04 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 6d640aa 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a 8fff3a7 cad654a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 
371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 | import glob
import os
import re
from pesq import pesq
import soundfile as sf
import torch
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
import argparse
import numpy as np
import warnings
from scipy.linalg import sqrtm
from tqdm import tqdm
import torchaudio
import torchaudio.transforms as T
import statistics # <-- 新增导入,用于计算平均值和标准差
from audiobox_aesthetics.infer import initialize_predictor
warnings.filterwarnings("ignore")
try:
from transformers import ClapModel, ClapProcessor
except ImportError:
print("Error: The 'transformers' library is not installed.")
print("Please install it to run FAD-CLAP calculations:")
print("pip install torch transformers")
exit(1)
def multi_mel_snr(reference, prediction, sr=48000):
    """Compute the Multi-Mel-SNR (dB) between a reference and a prediction.

    The prediction is first rescaled by the least-squares gain
    ``alpha = <ref, pred> / <pred, pred>`` (scale-invariant normalisation).
    An SNR is then computed on mel spectrograms at three resolutions and the
    three values are averaged.

    Args:
        reference: 1-D waveform (torch tensor or numpy array). ``torch.dot``
            below requires 1-D inputs of equal length.
        prediction: 1-D waveform to evaluate, same length as ``reference``.
        sr: sample rate of both signals in Hz. The mel filterbank spans
            0 Hz up to the Nyquist frequency ``sr / 2``.

    Returns:
        Mean SNR in dB over the three mel configurations, as a Python float.
    """
    # as_tensor(...).float() accepts numpy arrays and tensors alike and puts
    # both signals in float32, so torch.dot never sees mismatched dtypes.
    reference = torch.as_tensor(reference).float()
    prediction = torch.as_tensor(prediction).float()

    # Scale-invariant normalisation: optimal gain of prediction onto reference.
    alpha = torch.dot(reference, prediction) / (torch.dot(prediction, prediction) + 1e-8)
    prediction = alpha * prediction

    # (n_fft, hop_length, n_mels) for each mel resolution.
    configs = (
        (512, 256, 80),
        (1024, 512, 128),
        (2048, 1024, 192),
    )
    snrs = []
    for n_fft, hop, n_mels in configs:
        # f_max follows sr (Nyquist); a fixed 24000 only fits sr == 48000.
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=n_fft, hop_length=hop,
            n_mels=n_mels, f_min=0, f_max=sr / 2, power=2.0,
        )
        m_ref = mel(reference)
        m_pred = mel(prediction)
        # NOTE: with power=2.0 the mel output is already a power spectrogram;
        # the extra .pow(2) compares squared power values. Kept unchanged so
        # scores stay comparable with previously reported numbers.
        snr = 10 * torch.log10(m_ref.pow(2).sum() / ((m_ref - m_pred).pow(2).sum() + 1e-8))
        snrs.append(snr.item())
    return sum(snrs) / len(snrs)
def load_audio(file_path, target_sr=48000):
    """Load an audio file as a float32 tensor of shape (channels, samples).

    Mono files get a leading channel axis of size 1. If the file's sample
    rate differs from ``target_sr``, it is resampled (a warning is printed).

    Args:
        file_path: path to the audio file.
        target_sr: sample rate (Hz) of the returned waveform.

    Returns:
        torch.Tensor of shape (C, L), or None if the file could not be read or
        resampled. In that case a warning with the reason is printed.
    """
    try:
        wav, samplerate = sf.read(file_path)
        # Multi-channel data comes back frames-first; make it channel-first.
        wav = wav.T if wav.ndim > 1 else wav[np.newaxis, :]
        wav_tensor = torch.from_numpy(wav).float()
        if samplerate != target_sr:
            print(f"Warning: Resampling audio from {samplerate} to {target_sr}")
            resampler = T.Resample(orig_freq=samplerate, new_freq=target_sr)
            wav_tensor = resampler(wav_tensor)
        return wav_tensor
    except Exception as e:
        # Best-effort loader: callers treat None as "skip this file",
        # but the reason must not be swallowed silently.
        print(f"Warning: failed to load audio {file_path}: {e}")
        return None
def get_clap_embeddings(file_paths, model, processor, device, batch_size=16):
    """Embed audio files with a CLAP model.

    Files are processed ``batch_size`` at a time (48 kHz). Every channel of a
    file is embedded as a separate item, so a stereo file contributes two
    rows. Unreadable files and batches rejected by the processor/model are
    skipped, and a warning is printed for each skip so that a shrunken
    embedding set (which changes the FAD) is visible.

    Args:
        file_paths: list of audio file paths.
        model: CLAP model exposing ``get_audio_features`` (e.g. a transformers
            ``ClapModel``); it is moved to ``device``.
        processor: matching processor (e.g. a transformers ``ClapProcessor``).
        device: torch device used for inference.
        batch_size: number of files per processor/model call.

    Returns:
        np.ndarray with one row per embedded channel (batches concatenated
        along axis 0), or an empty array if nothing could be embedded.
    """
    model.to(device)
    all_embeddings = []
    skipped_files = 0
    skipped_batches = 0
    for start in tqdm(range(0, len(file_paths), batch_size), desc=" Calculating embeddings", ncols=100, leave=False):
        audio_batch = []
        for path in file_paths[start:start + batch_size]:
            wav_tensor = load_audio(path, target_sr=48000)
            if wav_tensor is None:
                skipped_files += 1
                print(f"Warning: could not load {path}; skipped for CLAP embedding.")
                continue
            audio_batch.extend(channel.numpy() for channel in wav_tensor)
        if not audio_batch:
            continue
        try:
            inputs = processor(audios=audio_batch, sampling_rate=48000, return_tensors="pt", padding=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                audio_features = model.get_audio_features(**inputs)
            all_embeddings.append(audio_features.cpu().numpy())
        except Exception as e:
            # Keep going with the other batches, but make the loss visible.
            skipped_batches += 1
            print(f"Warning: CLAP embedding failed for batch starting at file index {start}: {e}")
    if skipped_files or skipped_batches:
        print(f"Warning: {skipped_files} file(s) and {skipped_batches} batch(es) were skipped; "
              f"FAD-CLAP uses the remaining embeddings only.")
    if not all_embeddings:
        return np.array([])
    return np.concatenate(all_embeddings, axis=0)
def calculate_frechet_distance(embeddings1, embeddings2, eps=1e-6):
    """Return the Frechet distance between Gaussians fitted to two embedding sets.

    FD = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 * Tr(sqrt(S1 @ S2)),
    with means/covariances estimated from the rows of each (N, D) array.

    Args:
        embeddings1: array of shape (N1, D), one embedding per row.
        embeddings2: array of shape (N2, D).
        eps: diagonal offset added to both covariances when the exact matrix
            square root fails or is not finite. This happens when the
            covariances are singular, e.g. N < D.

    Returns:
        The distance as a float, or None if either set has fewer than two
        rows or no finite matrix square root could be computed.
    """
    if embeddings1.shape[0] < 2 or embeddings2.shape[0] < 2:
        return None
    mu1, mu2 = np.mean(embeddings1, axis=0), np.mean(embeddings2, axis=0)
    # atleast_2d keeps D == 1 working: np.cov returns a 0-d array there.
    sigma1 = np.atleast_2d(np.cov(embeddings1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(embeddings2, rowvar=False))
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Try the exact product first; if that is not usable, retry with a small
    # diagonal offset (the regularisation used by the reference FID code).
    covmean = None
    for jitter in (0.0, eps):
        offset = np.eye(sigma1.shape[0]) * jitter
        try:
            candidate = sqrtm((sigma1 + offset).dot(sigma2 + offset))
        except Exception:
            continue
        # Numerical error can introduce a tiny imaginary component.
        if np.iscomplexobj(candidate):
            candidate = candidate.real
        if np.isfinite(candidate).all():
            covmean = candidate
            break
    if covmean is None:
        return None
    return float(ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
def find_matching_pairs(target_dir, output_dir, target_index):
    """Pair every target file with its generated output files.

    Target files are expected to be named ``<id>.<ext>`` (e.g. ``0.flac``) and
    outputs ``<id>_DT<index>.<ext>`` (e.g. ``0_DT11.flac``).

    Args:
        target_dir: directory containing the reference files.
        output_dir: directory containing the generated files.
        target_index: optional regex alternation of accepted indices, e.g.
            ``"11|12"``. It is inserted into the regex unescaped on purpose so
            that ``|`` means "or". When None or empty, any numeric index
            is accepted.

    Returns:
        List of ``(target_path, output_path)`` tuples, ordered by target path
        and then by output path (both lexicographic).
    """
    pairs = []
    # glob.escape: directory names or ids containing "[", "*" or "?" must be
    # matched literally, not as glob wildcards.
    target_files = sorted(glob.glob(os.path.join(glob.escape(target_dir), "*.*")))
    print(f"Found {len(target_files)} target files in {target_dir}")
    index_pattern = f"({target_index})" if target_index else r"\d+"
    for target_file in target_files:
        target_id = os.path.splitext(os.path.basename(target_file))[0]
        output_pattern = os.path.join(glob.escape(output_dir), f"{glob.escape(target_id)}_DT*.*")
        # The glob is only a coarse prefilter; the regex enforces the exact name.
        regex = re.compile(rf"^{re.escape(target_id)}_DT{index_pattern}\.\w+$")
        matching_outputs = sorted(
            f for f in glob.glob(output_pattern) if regex.match(os.path.basename(f))
        )
        if matching_outputs:
            print(f"Target {target_id}: found {len(matching_outputs)} output files")
            pairs.extend((target_file, output_file) for output_file in matching_outputs)
        else:
            print(f"Target {target_id}: no matching output files found")
    return pairs
# --- PESQ computation ---
def calculate_pesq(target_wav, output_wav, target_sr=48000, pesq_sr=16000):
    """Compute PESQ between a reference and a degraded waveform.

    Only the first channel of each input is scored. Both signals are
    resampled from ``target_sr`` to ``pesq_sr``. PESQ is defined only at
    8 kHz (narrowband) and 16 kHz (wideband or narrowband). The wideband
    mode is used at 16 kHz and the narrowband mode otherwise.

    Args:
        target_wav: reference waveform tensor, shape (C, L) or (L,).
        output_wav: degraded waveform tensor, assumed aligned with target_wav
            and of the same length.
        target_sr: sample rate (Hz) of both inputs.
        pesq_sr: sample rate (Hz) PESQ is evaluated at.

    Returns:
        The PESQ score, or NaN if the pesq library rejects the input (a
        warning is printed).
    """
    # Reduce any leading (channel) axes to the first channel -> 1-D signal.
    while target_wav.ndim > 1:
        target_wav = target_wav[0]
    while output_wav.ndim > 1:
        output_wav = output_wav[0]

    if target_sr != pesq_sr:
        resampler = T.Resample(orig_freq=target_sr, new_freq=pesq_sr)
        target_wav = resampler(target_wav)
        output_wav = resampler(output_wav)
    target_np = target_wav.numpy()
    output_np = output_wav.numpy()

    # Wideband PESQ exists only at 16 kHz; hard-coding 'wb' made every other
    # pesq_sr (e.g. 8000) fail unconditionally.
    mode = 'wb' if pesq_sr == 16000 else 'nb'
    try:
        return pesq(pesq_sr, target_np, output_np, mode)
    except Exception as e:
        print(f"Warning: PESQ calculation failed for a pair. Error: {e}")
        return float('nan')
# Sample rate (Hz) every waveform is resampled to before scoring.
TARGET_SR = 48000
# Score axes produced by the AudioBox Aesthetics predictor.
AXES_NAME = ["CE", "CU", "PC", "PQ"]
# Files per aesthetics forward call; lower it if the predictor runs out of GPU memory.
AESTHETICS_CHUNK_SIZE = 64
LOCAL_AESTHETICS_CKPT = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/audiobox/audiobox_aes_checkpoint.pt"
LOCAL_CLAP_MODEL_PATH = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/clap-model"
# Per-pair metric keys, in report order.
_PAIRWISE_KEYS = ("sisnr", "pesq", "mel_snr")


def _parse_args():
    """Parse the command line and fill in the default output filename."""
    parser = argparse.ArgumentParser(description="Calculate SI-SNR and FAD-CLAP for audio pairs. All audio is resampled to 48000Hz.")
    parser.add_argument("--target_dir", '-t', required=True, type=str, help="Path to target audio directory")
    parser.add_argument("--output_dir", '-o', required=True, type=str, help="Path to output audio directory")
    parser.add_argument("--target_index", '-i', type=str, help="Index of target audio files, e.g. '11|12'")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for FAD-CLAP embedding calculation.")
    parser.add_argument("--output_file", type=str, help="Filename to save all evaluation results.")
    # Metric switches. SI-SNR, PESQ and aesthetics are opt-in.
    parser.add_argument("--calc_sisnr", action="store_true", help="Calculate Scale-Invariant SNR (SI-SNR).")
    parser.add_argument("--calc_pesq", action="store_true", help="Calculate Perceptual Evaluation of Speech Quality (PESQ).")
    parser.add_argument("--calc_aesthetics", action="store_true", help="Calculate AudioBox Aesthetics MOS.")
    # FAD-CLAP and Multi-Mel-SNR are on by default. `store_true` + `default=True`
    # could never be switched off; BooleanOptionalAction (Python 3.9+) keeps
    # `--calc_fad_clap` valid and adds `--no-calc_fad_clap` (same for mel-SNR).
    parser.add_argument("--calc_fad_clap", default=True, action=argparse.BooleanOptionalAction, help="Calculate Frechet Audio Distance (FAD-CLAP).")
    parser.add_argument("--calc_mel_snr", default=True, action=argparse.BooleanOptionalAction, help="Calculate Multi-Mel-SNR.")
    args = parser.parse_args()
    if not args.output_file:
        # normpath drops trailing separators: "out/" -> "out.txt", not "out/.txt".
        args.output_file = os.path.normpath(args.output_dir)
        if args.target_index:
            args.output_file += f"_{args.target_index}"
        args.output_file += ".txt"
    return args


def _load_aesthetics_predictor(enabled):
    """Return an AudioBox Aesthetics predictor, or None if disabled or loading fails."""
    if not enabled:
        return None
    try:
        print("\nLoading AudioBox Aesthetics predictor...")
        predictor = initialize_predictor(ckpt=LOCAL_AESTHETICS_CKPT)
    except Exception as e:
        print(f"Error loading AudioBox Aesthetics predictor: {e}. Aesthetics calculation will be skipped.")
        return None
    print("AudioBox Aesthetics predictor loaded successfully.")
    return predictor


def _align_pair(target_wav, output_wav, target_path, output_path):
    """Match the output's channel count to the target and trim both to the shorter length.

    Only mono/stereo mismatches are reconciled: a mono output is duplicated
    for a stereo target; a stereo output is averaged for a mono target.

    Raises:
        ValueError: on an unsupported channel mismatch or an empty overlap.
    """
    if target_wav.shape[0] != output_wav.shape[0]:
        print(f"Warning: shape mismatch: {target_path} -> {output_path}")
        if target_wav.shape[0] not in (1, 2) or output_wav.shape[0] not in (1, 2):
            raise ValueError(f"Skipping, unsupported shape: {target_path} -> {output_path}")
        if target_wav.shape[0] > output_wav.shape[0]:  # stereo target, mono output
            output_wav = output_wav.repeat(2, 1)
        else:  # mono target, stereo output
            output_wav = output_wav.mean(dim=0, keepdim=True)
    min_len = min(target_wav.shape[-1], output_wav.shape[-1])
    if min_len == 0:
        raise ValueError(f"Skipping, zero-length waveform: {target_path} -> {output_path}")
    return target_wav[..., :min_len], output_wav[..., :min_len]


def _calculate_pairwise_metrics(target_path, output_path, args, sisnr_calculator):
    """Compute the enabled per-pair metrics; disabled metrics are NaN.

    Returns:
        dict with keys "sisnr", "pesq" and "mel_snr".

    Raises:
        ValueError: when the pair must be skipped (missing or unreadable file,
            unsupported channel layout, empty overlap).
    """
    if not os.path.exists(target_path) or not os.path.exists(output_path):
        raise ValueError(f"Skipping, file not found: {target_path} -> {output_path}")
    target_wav = load_audio(target_path, TARGET_SR)
    output_wav = load_audio(output_path, TARGET_SR)
    if target_wav is None or output_wav is None:
        raise ValueError(f"Skipping, waveform not loaded: {target_path} -> {output_path}")
    target_wav, output_wav = _align_pair(target_wav, output_wav, target_path, output_path)

    metrics = dict.fromkeys(_PAIRWISE_KEYS, float('nan'))
    summary = f"{target_path}|{output_path}"
    if args.calc_sisnr:
        metrics['sisnr'] = sisnr_calculator(output_wav, target_wav).item()
        summary += f"|SI-SNR:{metrics['sisnr']:.4f}"
    if args.calc_pesq:
        metrics['pesq'] = calculate_pesq(target_wav, output_wav, TARGET_SR)
        summary += f"|PESQ:{metrics['pesq']:.4f}"
    if args.calc_mel_snr:
        # multi_mel_snr takes 1-D signals (and applies its own scale-invariant
        # normalisation), so score each channel and average.
        per_channel = [
            multi_mel_snr(target_wav[ch], output_wav[ch], sr=TARGET_SR)
            for ch in range(target_wav.shape[0])
        ]
        metrics['mel_snr'] = sum(per_channel) / len(per_channel) if per_channel else float('nan')
        summary += f"|Mel-SNR:{metrics['mel_snr']:.4f}"
    print(summary)
    return metrics


def _compute_aesthetics(predictor, output_paths):
    """Score output files with the aesthetics predictor, chunk by chunk.

    Returns:
        {axis: [score per output path]} with exactly len(output_paths) entries
        per axis. Entries are NaN when the predictor is unavailable, a chunk
        fails, or a result lacks an axis.
    """
    values = {axis: [] for axis in AXES_NAME}
    if predictor is None or not output_paths:
        return {axis: [float('nan')] * len(output_paths) for axis in AXES_NAME}
    print("\n--- Calculating AudioBox Aesthetics Scores (Batch) ---")
    for start in tqdm(range(0, len(output_paths), AESTHETICS_CHUNK_SIZE), desc=" Aesthetics chunks"):
        chunk_paths = output_paths[start:start + AESTHETICS_CHUNK_SIZE]
        try:
            results = predictor.forward([{"path": p} for p in chunk_paths])
            chunk_scores = []
            for j in range(len(chunk_paths)):
                if j < len(results) and all(axis in results[j] for axis in AXES_NAME):
                    chunk_scores.append({axis: results[j][axis] for axis in AXES_NAME})
                else:
                    chunk_scores.append(None)
        except Exception as e:
            print(f"\nError in chunk {start // AESTHETICS_CHUNK_SIZE}: {e}. Skipping chunk.")
            if "CUDA out of memory" in str(e):
                print("FATAL OOM: Please reduce AESTHETICS_CHUNK_SIZE and restart.")
            chunk_scores = [None] * len(chunk_paths)
        # Extend only after the whole chunk is resolved so every axis list
        # stays index-aligned with output_paths.
        for score in chunk_scores:
            for axis in AXES_NAME:
                values[axis].append(float('nan') if score is None else score[axis])
    return values


def _fmt_cell(value, width):
    """Format one table cell: '|' + value (4 decimals, or 'N/A' for NaN) left-aligned to width."""
    text = "N/A" if np.isnan(value) else f"{value:.4f}"
    return f"|{text:<{width}}"


def _write_pairwise_table(results_file, args, target_paths, output_paths, pairwise, aesthetics):
    """Write one row per pair; only enabled metrics get a column."""
    results_file.write("\n--- Pairwise Metrics ---\n")
    columns = []  # (header label, per-pair values, column width)
    if args.calc_sisnr:
        columns.append(('SI-SNR (dB)', pairwise['sisnr'], 15))
    if args.calc_pesq:
        columns.append(('PESQ', pairwise['pesq'], 8))
    if args.calc_mel_snr:
        columns.append(('Mel-SNR (dB)', pairwise['mel_snr'], 15))
    if args.calc_aesthetics:
        columns.extend((axis, aesthetics[axis], 10) for axis in AXES_NAME)

    header = f"{'Target Filename':<30}|{'Output Filename':<30}"
    header += "".join(f"|{label:<{width}}" for label, _, width in columns)
    results_file.write(header + "\n")
    results_file.write("-" * len(header) + "\n")

    print("\n--- Writing results to file ---")
    rows = tqdm(enumerate(zip(target_paths, output_paths)), total=len(target_paths), desc=" Writing results", ncols=100)
    for i, (target_path, output_path) in rows:
        line = f"{os.path.basename(target_path):<30}|{os.path.basename(output_path):<30}"
        line += "".join(_fmt_cell(values[i], width) for _, values, width in columns)
        results_file.write(line + "\n")


def _mean_std(scores):
    """Return (mean, sample std, count) over the non-NaN scores, or None if there are none."""
    valid = [s for s in scores if not np.isnan(s)]
    if not valid:
        return None
    std = statistics.stdev(valid) if len(valid) > 1 else 0.0
    return statistics.mean(valid), std, len(valid)


def _write_summary(results_file, args, pairwise, aesthetics):
    """Write mean/std of every enabled metric, ignoring NaN entries."""
    results_file.write("\n\n--- Overall Statistical Metrics ---\n")
    scalar_metrics = (
        (args.calc_sisnr, 'sisnr', "SI-SNR (dB)", "SI-SNR"),
        (args.calc_pesq, 'pesq', "PESQ", "PESQ"),
        (args.calc_mel_snr, 'mel_snr', "Multi-Mel-SNR", "Multi-Mel-SNR"),
    )
    for enabled, key, label, short_name in scalar_metrics:
        if not enabled:
            continue
        stats = _mean_std(pairwise[key])
        if stats is None:
            results_file.write(f"\nNo valid {short_name} values were calculated.\n")
            continue
        avg, std, count = stats
        results_file.write(f"\n{label} Average: {avg:.4f}\n")
        results_file.write(f"{label} Std Dev: {std:.4f} (from {count} samples)\n")

    if args.calc_aesthetics:
        results_file.write("\n--- Aesthetics MOS ---\n")
        for axis in AXES_NAME:
            stats = _mean_std(aesthetics[axis])
            if stats is None:
                results_file.write(f" {axis} (Avg/Std): N/A (No valid scores calculated)\n")
            else:
                avg, std, count = stats
                results_file.write(f" {axis} (Avg/Std): {avg:.4f} / {std:.4f} (from {count} samples)\n")


def _write_fad_clap(results_file, args, target_paths, output_paths):
    """Compute FAD-CLAP between all targets and all outputs and write it to the report."""
    print("\n--- Calculating FAD-CLAP (48kHz) ---")
    if not target_paths:
        results_file.write("\nFAD-CLAP: Skipped (No valid file pairs found).\n")
        return
    results_file.write(f"\nTotal pairs for FAD-CLAP: {len(target_paths)}\n")
    try:
        print("Loading CLAP model...")
        clap_model = ClapModel.from_pretrained(LOCAL_CLAP_MODEL_PATH, local_files_only=True)
        clap_processor = ClapProcessor.from_pretrained(LOCAL_CLAP_MODEL_PATH, local_files_only=True)
        clap_model.eval()
    except Exception as e:
        error_msg = f"Fatal Error: Could not load CLAP model. Error: {e}"
        print(error_msg)
        results_file.write(f"\nFAD-CLAP: {error_msg}\n")
        return
    print("CLAP model loaded successfully.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print("\nCalculating embeddings for all target files...")
    target_embeddings = get_clap_embeddings(target_paths, clap_model, clap_processor, device, args.batch_size)
    print("Calculating embeddings for all output files...")
    output_embeddings = get_clap_embeddings(output_paths, clap_model, clap_processor, device, args.batch_size)

    if target_embeddings.size == 0 or output_embeddings.size == 0:
        msg = "Could not calculate FAD-CLAP due to issues with embedding generation."
    else:
        print("Calculating Frechet Audio Distance (FAD)...")
        fad_score = calculate_frechet_distance(target_embeddings, output_embeddings)
        if fad_score is not None:
            final_fad_output = f"\nOverall FAD-CLAP Score: {fad_score:.4f}"
            print(final_fad_output)
            results_file.write(final_fad_output + "\n")
            return
        msg = "Could not calculate FAD-CLAP score."
    print(f"\n{msg}")
    results_file.write(f"\nFAD-CLAP: {msg}\n")


def main():
    """Evaluate generated audio against targets and write a text report.

    Phases:
      1. Pair files and compute per-pair metrics (SI-SNR, PESQ, Multi-Mel-SNR).
      2. Score outputs with AudioBox Aesthetics (optional).
      3. Write the per-pair table.
      4. Write summary statistics and, optionally, FAD-CLAP.
    """
    args = _parse_args()
    results_filename = args.output_file
    # Checked before any model is loaded so a name clash fails fast.
    if os.path.exists(results_filename):
        raise FileExistsError(f"Output file already exists: {results_filename}")

    aesthetics_predictor = _load_aesthetics_predictor(args.calc_aesthetics)
    sisnr_calculator = ScaleInvariantSignalNoiseRatio()

    # `with` guarantees the report is flushed and closed even if a phase raises.
    with open(results_filename, 'w', encoding='utf-8') as results_file:
        results_file.write("--- Audio Evaluation Results ---\n")
        print(f"所有结果将被写入文件: {results_filename}")

        # Phase 1: per-pair metrics.
        print("\n--- Calculating SI-SNR (48kHz) for each pair ---")
        print("--- Finding matching file pairs ---")
        pairs = find_matching_pairs(args.target_dir, args.output_dir, args.target_index)
        print(f"Found {len(pairs)} file pairs")
        target_paths, output_paths = [], []
        pairwise = {key: [] for key in _PAIRWISE_KEYS}
        for target_path, output_path in pairs:
            try:
                metrics = _calculate_pairwise_metrics(target_path, output_path, args, sisnr_calculator)
            except Exception as e:
                print(f"Error processing {target_path} -> {output_path}: {e}")
                continue
            # Record a pair only after all of its metrics succeeded, so every
            # list below stays index-aligned with target_paths/output_paths.
            target_paths.append(target_path)
            output_paths.append(output_path)
            for key in _PAIRWISE_KEYS:
                pairwise[key].append(metrics[key])

        # Phase 2: aesthetics (NaN-filled when disabled or the predictor failed to load).
        aesthetics = _compute_aesthetics(aesthetics_predictor, output_paths)

        # Phases 3 and 4: report.
        _write_pairwise_table(results_file, args, target_paths, output_paths, pairwise, aesthetics)
        _write_summary(results_file, args, pairwise, aesthetics)
        if args.calc_fad_clap:
            _write_fad_clap(results_file, args, target_paths, output_paths)

        results_file.write("\n--- End of Report ---")
    print(f"\nDone!!!! Save the result into {results_filename}。")
if __name__ == "__main__":
    # Run the evaluation pipeline only when executed as a script, not on import.
    main()