File size: 25,733 Bytes
c362e04
cad654a
8fff3a7
6d640aa
cad654a
 
 
 
 
 
 
 
8fff3a7
cad654a
 
 
 
 
 
 
 
 
 
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cad654a
 
 
 
 
 
 
 
 
 
 
 
 
 
c362e04
cad654a
 
 
 
 
 
 
 
 
 
 
8fff3a7
cad654a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fff3a7
c362e04
 
 
 
 
 
 
85dad86
c362e04
 
 
 
 
 
 
85dad86
c362e04
8fff3a7
 
3a1bba6
 
 
c362e04
 
 
 
 
 
 
 
 
 
cad654a
6d640aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cad654a
 
8fff3a7
 
 
cad654a
8fff3a7
 
 
 
 
 
 
 
cad654a
c362e04
8fff3a7
 
 
 
 
cad654a
 
 
8fff3a7
 
 
 
 
 
 
 
 
cad654a
 
 
8fff3a7
 
cad654a
 
 
 
 
 
 
 
6d640aa
8fff3a7
cad654a
 
 
 
 
 
 
 
 
 
8fff3a7
c362e04
 
 
 
 
 
 
8de1b87
 
 
 
 
 
 
 
 
c362e04
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c362e04
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c362e04
cad654a
 
 
 
8fff3a7
cad654a
 
 
8fff3a7
cad654a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
cad654a
 
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
cad654a
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cad654a
 
 
 
8fff3a7
 
 
 
cad654a
8fff3a7
 
 
 
 
 
 
 
 
 
 
cad654a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d640aa
 
 
 
 
 
 
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
cad654a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fff3a7
 
cad654a
8fff3a7
 
 
 
 
cad654a
8fff3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cad654a
8fff3a7
 
cad654a
8fff3a7
 
 
 
 
 
 
 
 
 
 
cad654a
8fff3a7
cad654a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
import glob
import os
import re
from pesq import pesq
import soundfile as sf
import torch
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
import argparse
import numpy as np
import warnings
from scipy.linalg import sqrtm
from tqdm import tqdm
import torchaudio
import torchaudio.transforms as T
import statistics # <-- 新增导入,用于计算平均值和标准差
from audiobox_aesthetics.infer import initialize_predictor
warnings.filterwarnings("ignore")

# CLAP (via HuggingFace `transformers`) is required for the FAD-CLAP metric;
# fail fast at import time with install instructions rather than mid-run.
try:
    from transformers import ClapModel, ClapProcessor
except ImportError:
    print("Error: The 'transformers' library is not installed.")
    print("Please install it to run FAD-CLAP calculations:")
    print("pip install torch transformers")
    exit(1)



def multi_mel_snr(reference, prediction, sr=48000):
    """Compute Multi-Mel-SNR between reference and prediction.

    The prediction is first rescaled by the optimal scalar (scale-invariant
    projection), then an SNR is computed on mel power spectrograms at three
    resolutions and the three values are averaged.

    Args:
        reference: 1-D waveform (torch tensor or numpy array).
        prediction: 1-D waveform, same length as ``reference``.
        sr: sample rate of both waveforms.

    Returns:
        Average SNR (dB, float) over the three mel configurations.
    """
    if not isinstance(reference, torch.Tensor):
        reference = torch.from_numpy(reference).float()
    if not isinstance(prediction, torch.Tensor):
        prediction = torch.from_numpy(prediction).float()

    # Scale-invariant normalization: project prediction onto reference.
    alpha = torch.dot(reference, prediction) / (torch.dot(prediction, prediction) + 1e-8)
    prediction = alpha * prediction

    # Three mel resolutions: (n_fft, hop_length, n_mels).
    configs = [
        (512, 256, 80),
        (1024, 512, 128),
        (2048, 1024, 192),
    ]

    # BUGFIX: f_max was hard-coded to 24000 Hz even though `sr` is a
    # parameter, which is wrong (beyond Nyquist) for any sr < 48000.
    # Using sr / 2 is identical at the 48 kHz default.
    f_max = sr / 2

    snrs = []
    for n_fft, hop, n_mels in configs:
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=n_fft, hop_length=hop,
            n_mels=n_mels, f_min=0, f_max=f_max, power=2.0
        )
        M_ref = mel(reference)
        M_pred = mel(prediction)
        # SNR over the mel power spectrogram (epsilon avoids div-by-zero
        # when the prediction matches the reference exactly).
        snr = 10 * torch.log10(M_ref.pow(2).sum() / ((M_ref - M_pred).pow(2).sum() + 1e-8))
        snrs.append(snr.item())

    return sum(snrs) / len(snrs)

def load_audio(file_path, target_sr=48000):
    """Load an audio file as a float32 tensor shaped [channels, samples].

    The waveform is resampled to ``target_sr`` when the file's native rate
    differs.  Returns ``None`` on any failure (best-effort contract relied
    on by callers).
    """
    try:
        wav, samplerate = sf.read(file_path)

        # soundfile returns [samples, channels]; normalize to [C, L].
        if wav.ndim > 1:
            wav = wav.T
        else:
            wav = wav[np.newaxis, :]

        wav_tensor = torch.from_numpy(wav).float()

        if samplerate != target_sr:
            print(f"Warning: Resampling audio from {samplerate} to {target_sr}")
            resampler = T.Resample(orig_freq=samplerate, new_freq=target_sr)
            wav_tensor = resampler(wav_tensor)

        return wav_tensor
    except Exception as e:
        # BUGFIX: errors used to be swallowed silently, making bad files
        # impossible to diagnose.  Still best-effort (returns None), but the
        # failure is now reported.
        print(f"Warning: failed to load {file_path}: {e}")
        return None

def get_clap_embeddings(file_paths, model, processor, device, batch_size=16):
    """Return CLAP audio embeddings for *file_paths* as one (N, D) array.

    Files are loaded at 48 kHz and each channel is embedded as its own
    sample, so N may differ from len(file_paths): larger for multi-channel
    audio, smaller when files or whole batches fail (failures are skipped
    best-effort).  Returns an empty array when nothing could be embedded.
    """
    model.to(device)
    all_embeddings = []
    
    for i in tqdm(range(0, len(file_paths), batch_size), desc="  Calculating embeddings", ncols=100, leave=False):
        batch_paths = file_paths[i:i+batch_size]
        audio_batch = []
        for path in batch_paths:
            try:
                wav_tensor = load_audio(path, target_sr=48000)
                if wav_tensor is None:
                    continue
                
                # Flatten channels: each channel becomes its own batch item.
                for channel in wav_tensor:
                    audio_batch.append(channel.numpy())
            except Exception:
                # Best-effort: unreadable files are dropped silently.
                continue

        if not audio_batch:
            continue

        try:
            inputs = processor(audios=audio_batch, sampling_rate=48000, return_tensors="pt", padding=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            
            with torch.no_grad():
                audio_features = model.get_audio_features(**inputs)
            
            all_embeddings.append(audio_features.cpu().numpy())
        except Exception:
            # Best-effort: a failed batch (e.g. OOM) is dropped, not fatal.
            continue
            
    if not all_embeddings:
        return np.array([])
        
    return np.concatenate(all_embeddings, axis=0)

def calculate_frechet_distance(embeddings1, embeddings2):
    """Frechet distance between Gaussian fits of two embedding sets.

    Fits a mean/covariance Gaussian to each (N, D) embedding array and
    returns ||mu1 - mu2||^2 + Tr(S1 + S2 - 2*sqrt(S1 S2)).  Returns None
    when either set has fewer than two samples (covariance undefined) or
    the matrix square root fails.
    """
    if min(embeddings1.shape[0], embeddings2.shape[0]) < 2:
        return None

    mu1 = embeddings1.mean(axis=0)
    mu2 = embeddings2.mean(axis=0)
    sigma1 = np.cov(embeddings1, rowvar=False)
    sigma2 = np.cov(embeddings2, rowvar=False)

    mean_term = np.sum(np.square(mu1 - mu2))

    try:
        covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    except Exception:
        return None

    # sqrtm may return spurious tiny imaginary parts from numerical error.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return mean_term + np.trace(sigma1 + sigma2 - 2.0 * covmean)

def find_matching_pairs(target_dir, output_dir, target_index):
    """Pair each target file with its matching output files.

    Targets are named like ``0.flac``, ``1.flac``, ... and outputs like
    ``{target_id}_DT{index}.flac``.  When ``target_index`` is given (a regex
    alternation such as ``"11|12"``), only outputs whose DT index matches it
    are paired; otherwise any numeric DT index matches.

    Returns a list of (target_path, output_path) tuples.
    """
    matched_pairs = []

    targets = sorted(glob.glob(os.path.join(target_dir, "*.*")))
    print(f"Found {len(targets)} target files in {target_dir}")

    for target_path in targets:
        stem = os.path.splitext(os.path.basename(target_path))[0]

        candidates = glob.glob(os.path.join(output_dir, f"{stem}_DT*.*"))
        if target_index is None:
            pattern = re.compile(rf"^{re.escape(stem)}_DT\d+\.\w+$")
        else:
            pattern = re.compile(rf"^{re.escape(stem)}_DT({target_index})\.\w+$")
        outputs = sorted(
            path for path in candidates if pattern.match(os.path.basename(path))
        )

        if not outputs:
            print(f"Target {stem}: no matching output files found")
            continue

        print(f"Target {stem}: found {len(outputs)} output files")
        matched_pairs.extend((target_path, out) for out in outputs)

    return matched_pairs

# --- PESQ computation helper ---
def calculate_pesq(target_wav, output_wav, target_sr=48000, pesq_sr=16000):
    """
    Compute the PESQ score (wideband mode at 16 kHz).
    target_wav and output_wav must be aligned [C, L] tensors with matching
    channel layouts.  Returns NaN when the PESQ library rejects the pair.
    """
    # Ensure mono input (C=1): waveforms are [C, L]; when C > 1 we simply
    # keep the first channel.
    if target_wav.ndim > 1 and target_wav.shape[0] > 1:
        # take the first channel
        target_wav = target_wav[0:1, :]
    if output_wav.ndim > 1 and output_wav.shape[0] > 1:
        # take the first channel
        output_wav = output_wav[0:1, :]
    # Convert tensors to numpy arrays.
    target_np = target_wav.squeeze(0).numpy()
    output_np = output_wav.squeeze(0).numpy()
    
    # Defensive: guarantee mono arrays for PESQ (first channel if somehow
    # still multi-channel; normally unreachable after the slicing above).
    if target_np.ndim > 1:
        target_np = target_np[0]
        output_np = output_np[0]
        
    # Resample to the rate PESQ expects (16000 Hz).
    if target_sr != pesq_sr:
        resampler = T.Resample(orig_freq=target_sr, new_freq=pesq_sr)
        target_resampled = resampler(target_wav).squeeze(0).numpy()
        output_resampled = resampler(output_wav).squeeze(0).numpy()
    else:
        target_resampled = target_np
        output_resampled = output_np
    
    try:
        # Wideband ('wb') mode, since the audio was resampled to 16 kHz.
        score = pesq(pesq_sr, target_resampled, output_resampled, 'wb')
        return score
    except Exception as e:
        print(f"Warning: PESQ calculation failed for a pair. Error: {e}")
        return float('nan')

def main():
    """Evaluate output audio against target audio and write a text report.

    Phases:
      1. Pair target/output files and compute pairwise metrics
         (SI-SNR / PESQ / Multi-Mel-SNR).
      2. Batch AudioBox Aesthetics scoring of the output files.
      3. Write per-pair result rows.
      4. Write overall statistics (mean / std dev per metric).
      5. Compute a corpus-level FAD-CLAP score.
    """
    parser = argparse.ArgumentParser(description="Calculate SI-SNR and FAD-CLAP for audio pairs. All audio is resampled to 48000Hz.")
    parser.add_argument("--target_dir", '-t', required=True, type=str, help="Path to target audio directory")
    parser.add_argument("--output_dir", '-o', required=True, type=str, help="Path to output audio directory")
    parser.add_argument("--target_index", '-i', type=str, help="Index of target audio files, e.g. '11|12'")
    parserser = None  # placeholder removed below
    del parserser
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for FAD-CLAP embedding calculation.")
    parser.add_argument("--output_file", type=str, help="Filename to save all evaluation results.")
    # Metric switches.  NOTE: --calc_fad_clap / --calc_mel_snr combine
    # default=True with action="store_true", so they cannot be disabled from
    # the CLI; kept unchanged for interface compatibility.
    parser.add_argument("--calc_sisnr", action="store_true", help="Calculate Scale-Invariant SNR (SI-SNR).")
    parser.add_argument("--calc_pesq", action="store_true", help="Calculate Perceptual Evaluation of Speech Quality (PESQ).")
    parser.add_argument("--calc_aesthetics", action="store_true", help="Calculate AudioBox Aesthetics MOS.")
    parser.add_argument("--calc_fad_clap", default=True, action="store_true", help="Calculate Frechet Audio Distance (FAD-CLAP).")
    parser.add_argument("--calc_mel_snr", default=True, action="store_true", help="Calculate Multi-Mel-SNR.")

    args = parser.parse_args()

    # Derive a default report filename from the output directory.
    if not args.output_file:
        args.output_file = (args.output_dir[:-1] if args.output_dir.endswith('/') else args.output_dir)
        if args.target_index:
            args.output_file += f"_{args.target_index}"
        args.output_file += ".txt"

    # Initialize the AudioBox Aesthetics predictor only when requested.
    # BUGFIX: this gate previously used `assert args.calc_aesthetics` inside
    # a try/except for control flow, which breaks under `python -O`.
    AXES_NAME = ["CE", "CU", "PC", "PQ"]
    LOCAL_AESTHETICS_CKPT = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/audiobox/audiobox_aes_checkpoint.pt"
    aesthetics_predictor = None
    if args.calc_aesthetics:
        try:
            print("\nLoading AudioBox Aesthetics predictor...")
            aesthetics_predictor = initialize_predictor(ckpt=LOCAL_AESTHETICS_CKPT)
            print("AudioBox Aesthetics predictor loaded successfully.")
        except Exception as e:
            print(f"Error loading AudioBox Aesthetics predictor: {e}. Aesthetics calculation will be skipped.")
            aesthetics_predictor = None

    # Open the report file; refuse to overwrite an existing one.
    RESULTS_FILENAME = args.output_file
    if os.path.exists(RESULTS_FILENAME):
        raise Exception(f"Output file already exists: {RESULTS_FILENAME}")
    results_file = open(RESULTS_FILENAME, 'w', encoding='utf-8')
    results_file.write("--- Audio Evaluation Results ---\n")
    print(f"所有结果将被写入文件: {RESULTS_FILENAME}")

    sisnr_calculator = ScaleInvariantSignalNoiseRatio()
    all_target_paths = []
    all_output_paths = []
    all_aesthetics_values = {axis: [] for axis in AXES_NAME}

    # ----------------------------------------------------
    # PHASE 1: walk the pair list and compute pairwise metrics
    # ----------------------------------------------------

    print("\n--- Calculating SI-SNR (48kHz) for each pair ---")
    results_file.write("\n--- Pairwise SI-SNR (dB) ---\n")

    TARGET_SR = 48000

    def calculate_pairwise_metrics(target_path, output_path, args, results_list):
        """Compute SI-SNR / PESQ / Multi-Mel-SNR for one pair and record its
        paths.  Raises on unusable pairs so the caller can log and skip."""
        if not os.path.exists(target_path) or not os.path.exists(output_path):
            raise Exception(f"Skipping, file not found: {target_path} -> {output_path}")
        target_wav = load_audio(target_path, TARGET_SR)
        output_wav = load_audio(output_path, TARGET_SR)
        if target_wav is None or output_wav is None:
            raise Exception(f"Skipping, waveform not loaded: {target_path} -> {output_path}")
        # Reconcile channel counts (only mono vs stereo is supported).
        if target_wav.shape[0] != output_wav.shape[0]:
            print(f"Warning: shape mismatch: {target_path} -> {output_path}")
            if target_wav.shape[0] not in [1, 2]:
                raise Exception(f"Skipping, unsupported shape: {target_path} -> {output_path}")
            if output_wav.shape[0] not in [1, 2]:
                raise Exception(f"Skipping, unsupported shape: {target_path} -> {output_path}")
            if target_wav.shape[0] > output_wav.shape[0]:  # target stereo, output mono -> duplicate
                output_wav = output_wav.repeat(2, 1)
            else:  # target mono, output stereo -> downmix
                output_wav = output_wav.mean(dim=0, keepdim=True)
        # Trim both to the shorter length so samples stay aligned.
        min_len = min(target_wav.shape[-1], output_wav.shape[-1])
        target_wav = target_wav[..., :min_len]
        output_wav = output_wav[..., :min_len]
        if target_wav.shape[-1] == 0:
            raise Exception(f"Skipping, zero-length waveform: {target_path} -> {output_path}")

        # --- SI-SNR ---
        sisnr_val = float('nan')
        if args.calc_sisnr:
            sisnr_val = sisnr_calculator(output_wav, target_wav).item()
        results_list['sisnr'].append(sisnr_val)

        # --- PESQ ---
        pesq_val = float('nan')
        if args.calc_pesq:
            pesq_val = calculate_pesq(target_wav, output_wav, TARGET_SR)
        results_list['pesq'].append(pesq_val)

        # --- Multi-Mel-SNR ---
        mel_snr_val = float('nan')
        if args.calc_mel_snr:
            # multi_mel_snr takes mono input, so score per channel and average.
            mel_snrs = []
            for ch in range(target_wav.shape[0]):
                # multi_mel_snr performs its own scale-invariant
                # normalization, so the raw waveforms are passed in.
                mel_snrs.append(multi_mel_snr(target_wav[ch], output_wav[ch], sr=TARGET_SR))
            mel_snr_val = sum(mel_snrs) / len(mel_snrs) if mel_snrs else float('nan')
        results_list['mel_snr'].append(mel_snr_val)

        output_str = f"{target_path}|{output_path}"
        if args.calc_sisnr:
            output_str += f"|SI-SNR:{sisnr_val:.4f}"
        if args.calc_pesq:
            output_str += f"|PESQ:{pesq_val:.4f}"
        if args.calc_mel_snr:
            output_str += f"|Mel-SNR:{mel_snr_val:.4f}"
        print(output_str)

        # Paths are recorded only for fully processed pairs.
        all_target_paths.append(target_path)
        all_output_paths.append(output_path)

    all_pairwise_values = {
        'sisnr': [],
        'pesq': [],
        'mel_snr': []
    }

    print("--- Finding matching file pairs ---")
    pairs = find_matching_pairs(args.target_dir, args.output_dir, args.target_index)
    print(f"Found {len(pairs)} file pairs")
    for target_path, output_path in pairs:
        try:
            calculate_pairwise_metrics(target_path, output_path, args, all_pairwise_values)
        except Exception as e:
            print(f"Error processing {target_path} -> {output_path}: {e}")
            continue

    # ----------------------------------------------------
    # PHASE 2: batch AudioBox Aesthetics scoring
    # ----------------------------------------------------
    # BUGFIX: the chunk loop was previously dedented outside this `if`, so
    # it ran (and logged per-chunk errors) even when aesthetics was disabled
    # or the predictor had failed to load.
    AESTHETICS_CHUNK_SIZE = 64
    if args.calc_aesthetics and aesthetics_predictor and all_output_paths:
        print("\n--- Calculating AudioBox Aesthetics Scores (Batch) ---")

        for i in tqdm(range(0, len(all_output_paths), AESTHETICS_CHUNK_SIZE), desc="  Aesthetics chunks"):
            # Paths for the current chunk.
            chunk_paths = all_output_paths[i:i + AESTHETICS_CHUNK_SIZE]
            aesthetics_input_list = [{"path": p} for p in chunk_paths]

            try:
                # Chunked batch inference.
                aesthetics_results = aesthetics_predictor.forward(aesthetics_input_list)

                # Collect scores; pad with NaN where a result is missing or
                # lacks one of the expected axes.
                num_outputs = len(chunk_paths)
                num_results = len(aesthetics_results)
                for j in range(num_outputs):
                    if j < num_results and all(axis in aesthetics_results[j] for axis in AXES_NAME):
                        score_dict = aesthetics_results[j]
                        for axis in AXES_NAME:
                            all_aesthetics_values[axis].append(score_dict[axis])
                    else:
                        for axis in AXES_NAME:
                            all_aesthetics_values[axis].append(float('nan'))

            except Exception as e:
                # Catches OOM and any other inference failure for this chunk.
                print(f"\nError in chunk {i//AESTHETICS_CHUNK_SIZE}: {e}. Skipping chunk.")
                # Pad the whole chunk with NaN so lengths stay aligned.
                for axis in AXES_NAME:
                    all_aesthetics_values[axis].extend([float('nan')] * len(chunk_paths))
                if "CUDA out of memory" in str(e):
                    print("FATAL OOM: Please reduce AESTHETICS_CHUNK_SIZE and restart.")

    # Pad the aesthetics lists whenever scoring was skipped (disabled flag,
    # failed predictor load, or no outputs) so every per-axis list matches
    # the number of pairs.
    if len(all_target_paths) > 0:
        for axis in AXES_NAME:
            if len(all_aesthetics_values[axis]) < len(all_target_paths):
                all_aesthetics_values[axis].extend([float('nan')] * (len(all_target_paths) - len(all_aesthetics_values[axis])))

    # ----------------------------------------------------
    # PHASE 3: write per-pair result rows
    # ----------------------------------------------------
    # Sanity check: every metric list must line up with the pair list.
    num_pairs = len(all_target_paths)
    for metric_name, scores in all_pairwise_values.items():
        if len(scores) != num_pairs:
            raise RuntimeError(f"内部错误:指标 '{metric_name}' 的结果数量 ({len(scores)}) 与文件对数量 ({num_pairs}) 不匹配。")

    # Same check for the aesthetics axes.
    if args.calc_aesthetics:
        for axis in AXES_NAME:
            scores = all_aesthetics_values[axis]
            if len(scores) != num_pairs:
                raise RuntimeError(f"内部错误:Aesthetics 指标 '{axis}' 的结果数量 ({len(scores)}) 与文件对数量 ({num_pairs}) 不匹配。")

    results_file.write("\n--- Pairwise Metrics ---\n")

    # Build the header row dynamically from the enabled metrics.
    header_metrics = f"{'Target Filename':<30}|{'Output Filename':<30}"
    if args.calc_sisnr:
        header_metrics += f"|{'SI-SNR (dB)':<15}"
    if args.calc_pesq:
        header_metrics += f"|{'PESQ':<8}"
    if args.calc_mel_snr:
        header_metrics += f"|{'Mel-SNR (dB)':<15}"
    if args.calc_aesthetics:
        for axis in AXES_NAME:
            header_metrics += f"|{axis:<10}"

    results_file.write(header_metrics + "\n")
    results_file.write("-" * len(header_metrics) + "\n")

    print("\n--- Writing results to file ---")

    for i in tqdm(range(num_pairs), desc="  Writing results", ncols=100):
        target_filename = os.path.basename(all_target_paths[i])
        output_filename = os.path.basename(all_output_paths[i])
        result_line = f"{target_filename:<30}|{output_filename:<30}"

        if args.calc_sisnr:
            sisnr_item = all_pairwise_values['sisnr'][i]
            result_line += f"|{sisnr_item:<15.4f}"
        if args.calc_pesq:
            pesq_item = all_pairwise_values['pesq'][i]
            pesq_str = f"{pesq_item:<8.4f}" if not np.isnan(pesq_item) else "N/A   "
            result_line += f"|{pesq_str}"
        if args.calc_mel_snr:
            mel_snr_item = all_pairwise_values['mel_snr'][i]
            mel_snr_str = f"{mel_snr_item:<15.4f}" if not np.isnan(mel_snr_item) else "N/A           "
            result_line += f"|{mel_snr_str}"

        # Aesthetics columns (NaN rendered as N/A); always appended,
        # matching the original row format.
        aesthetics_part = ""
        for axis in AXES_NAME:
            score = all_aesthetics_values[axis][i]
            aesthetics_str = f"{score:.4f}" if not np.isnan(score) else "N/A"
            aesthetics_part += f"|{aesthetics_str:<10}"

        results_file.write(result_line + aesthetics_part + "\n")

    # ----------------------------------------------------
    # PHASE 4: overall statistics (mean / std dev)
    # ----------------------------------------------------

    results_file.write("\n\n--- Overall Statistical Metrics ---\n")
    # SI-SNR statistics
    if args.calc_sisnr and all_pairwise_values['sisnr']:
        scores = all_pairwise_values['sisnr']
        if scores:
            avg_sisnr = statistics.mean(scores)
            std_sisnr = statistics.stdev(scores) if len(scores) > 1 else 0.0
            results_file.write(f"SI-SNR (dB) Average: {avg_sisnr:.4f}\n")
            results_file.write(f"SI-SNR (dB) Std Dev: {std_sisnr:.4f}\n")
        else:
            results_file.write("No valid SI-SNR values were calculated.\n")

    # PESQ statistics (NaN entries excluded)
    if args.calc_pesq and all_pairwise_values['pesq']:
        scores = all_pairwise_values['pesq']
        valid_pesq_scores = [s for s in scores if not np.isnan(s)]
        if valid_pesq_scores:
            avg_pesq = statistics.mean(valid_pesq_scores)
            std_pesq = statistics.stdev(valid_pesq_scores) if len(valid_pesq_scores) > 1 else 0.0
            results_file.write(f"\nPESQ Average: {avg_pesq:.4f}\n")
            results_file.write(f"PESQ Std Dev: {std_pesq:.4f} (from {len(valid_pesq_scores)} samples)\n")
        else:
            results_file.write("\nNo valid PESQ values were calculated.\n")

    # Multi-Mel-SNR statistics (NaN entries excluded)
    if args.calc_mel_snr and all_pairwise_values['mel_snr']:
        scores = all_pairwise_values['mel_snr']
        valid_scores = [s for s in scores if not np.isnan(s)]
        if valid_scores:
            avg_mel_snr = statistics.mean(valid_scores)
            std_mel_snr = statistics.stdev(valid_scores) if len(valid_scores) > 1 else 0.0
            results_file.write(f"\nMulti-Mel-SNR Average: {avg_mel_snr:.4f}\n")
            results_file.write(f"Multi-Mel-SNR Std Dev: {std_mel_snr:.4f} (from {len(valid_scores)} samples)\n")
        else:
            results_file.write("\nNo valid Multi-Mel-SNR values were calculated.\n")

    # Aesthetics statistics
    results_file.write("\n--- Aesthetics MOS ---\n")
    for axis in AXES_NAME:
        scores = all_aesthetics_values[axis]
        valid_scores = [s for s in scores if not np.isnan(s)]

        if valid_scores:
            avg_aesthetics = statistics.mean(valid_scores)
            std_aesthetics = statistics.stdev(valid_scores) if len(valid_scores) > 1 else 0.0
            results_file.write(f"  {axis} (Avg/Std): {avg_aesthetics:.4f} / {std_aesthetics:.4f} (from {len(valid_scores)} samples)\n")
        else:
            results_file.write(f"  {axis} (Avg/Std): N/A (No valid scores calculated)\n")

    # ----------------------------------------------------
    # PHASE 5: FAD-CLAP
    # ----------------------------------------------------
    # BUGFIX: the CLAP model load and the embedding computation were
    # previously dedented outside their intended branches, so the model
    # loaded even with zero pairs and `clap_model` could hit a NameError.
    if args.calc_fad_clap:
        print("\n--- Calculating FAD-CLAP (48kHz) ---")

        if not all_target_paths:
            results_file.write("\nFAD-CLAP: Skipped (No valid file pairs found).\n")
        else:
            clap_model = None
            clap_processor = None
            try:
                results_file.write(f"\nTotal pairs for FAD-CLAP: {len(all_target_paths)}\n")
                print("Loading CLAP model...")
                LOCAL_MODEL_PATH = "/inspire/hdd/global_user/chenxie-25019/HaoQiu/EVAL_MODEL/clap-model"  # locally downloaded model path
                clap_model = ClapModel.from_pretrained(LOCAL_MODEL_PATH, local_files_only=True)
                clap_processor = ClapProcessor.from_pretrained(LOCAL_MODEL_PATH, local_files_only=True)
                clap_model.eval()
                print("CLAP model loaded successfully.")

            except Exception as e:
                error_msg = f"Fatal Error: Could not load CLAP model. Error: {e}"
                print(error_msg)
                results_file.write(f"\nFAD-CLAP: {error_msg}\n")

            if clap_model and clap_processor:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                print(f"Using device: {device}")

                print("\nCalculating embeddings for all target files...")
                target_embeddings = get_clap_embeddings(all_target_paths, clap_model, clap_processor, device, args.batch_size)

                print("Calculating embeddings for all output files...")
                output_embeddings = get_clap_embeddings(all_output_paths, clap_model, clap_processor, device, args.batch_size)

                if target_embeddings.size > 0 and output_embeddings.size > 0:
                    print("Calculating Frechet Audio Distance (FAD)...")
                    fad_score = calculate_frechet_distance(target_embeddings, output_embeddings)
                    if fad_score is not None:
                        final_fad_output = f"\nOverall FAD-CLAP Score: {fad_score:.4f}"
                        print(final_fad_output)
                        results_file.write(final_fad_output + "\n")
                    else:
                        msg = "\nCould not calculate FAD-CLAP score."
                        print(msg)
                        results_file.write(f"\nFAD-CLAP: {msg}\n")
                else:
                    msg = "\nCould not calculate FAD-CLAP due to issues with embedding generation."
                    print(msg)
                    results_file.write(f"\nFAD-CLAP: {msg}\n")

    # Close the report file.
    results_file.write("\n--- End of Report ---")
    results_file.close()
    print(f"\nDone!!!! Save the result into {RESULTS_FILENAME}。")

# CLI entry point; keeps the module importable without side effects.
if __name__ == "__main__":
    main()