File size: 3,925 Bytes
dbbd709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import csv
import glob
from tqdm import tqdm
import torch
import torchaudio
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio


def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
    """
    Compute SDR and SI-SDR between the sum of the separated tracks and the original audio.

    Each separated track is loaded and resampled to the original file's sample
    rate if it differs. The tracks are summed, and every signal is truncated to
    the shortest length among them. The sum is then scored against the original.

    Args:
        original_audio_path: str, path to the original (unseparated) audio file.
        separated_audio_paths: list[str], paths to the separated audio tracks.

    Returns:
        tuple[float, float]: ``(sdr, sisdr)``, in dB.

    Raises:
        ValueError: if ``separated_audio_paths`` is empty.
    """
    if not separated_audio_paths:
        # Without this guard the function would fail later with an opaque
        # AttributeError when slicing ``None``.
        raise ValueError(f"no separated audio files given for {original_audio_path}")

    # The original audio is the reference; its sample rate is the common time base.
    original_waveform, sample_rate = torchaudio.load(original_audio_path)

    combined_waveform = None
    for path in separated_audio_paths:
        separated_waveform, separated_rate = torchaudio.load(path)
        if separated_rate != sample_rate:
            # Comparing sample-by-sample at different rates would be meaningless,
            # so bring the track onto the reference rate first.
            separated_waveform = torchaudio.functional.resample(
                separated_waveform, separated_rate, sample_rate
            )

        if combined_waveform is None:
            combined_waveform = separated_waveform
        else:
            # Truncate BOTH operands to their common length: an earlier, shorter
            # track may have already made the running sum shorter than this one.
            length = min(combined_waveform.size(1), separated_waveform.size(1))
            combined_waveform = combined_waveform[:, :length] + separated_waveform[:, :length]

    # Make the mixture and the reference the same length before scoring.
    length = min(original_waveform.size(1), combined_waveform.size(1))
    original_waveform = original_waveform[:, :length]
    combined_waveform = combined_waveform[:, :length]

    # torchmetrics convention: metric(preds, target).
    sisdr = ScaleInvariantSignalDistortionRatio()(combined_waveform, original_waveform).item()
    sdr = SignalDistortionRatio()(combined_waveform, original_waveform).item()

    return sdr, sisdr


def _evaluate_split(dset, src_data_root, sep_data_root, nested=False):
    """
    Score every separated clip of one dataset split and write the results to CSV.

    The output file is ``<sep_data_root>/<dset>.csv`` with columns
    ``video, sdr, sisdr``. There is one row per entry found under
    ``<sep_data_root>/<dset>``, and the values are formatted to 3 decimals.

    Args:
        dset: str, split name, e.g. ``'balanced_train_segments'``.
        src_data_root: str, root of the original audio; each original is read from
            ``<src_data_root>/<dset>[/<part>]/<video>.wav``.
        sep_data_root: str, root of the separation output; each entry's files are
            treated as the separated tracks of one clip.
        nested: bool, set when the split is organised as ``<dset>/<part>/<video>``
            (as for ``unbalanced_train_segments``) instead of ``<dset>/<video>``.
    """
    if nested:
        pattern = os.path.join(sep_data_root, dset, '*', '*')
    else:
        pattern = os.path.join(sep_data_root, dset, '*')

    csv_path = os.path.join(sep_data_root, dset + '.csv')
    # newline='' is required by the csv module to avoid spurious blank lines;
    # the context manager guarantees the file is flushed and closed.
    with open(csv_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['video', 'sdr', 'sisdr'])
        # Sorting makes the row order (and the tqdm progress) reproducible.
        for video_path in tqdm(sorted(glob.glob(pattern))):
            video = os.path.basename(video_path)
            if nested:
                part = os.path.basename(os.path.dirname(video_path))
                original_audio_path = os.path.join(src_data_root, dset, part, video + '.wav')
            else:
                original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = sorted(glob.glob(os.path.join(video_path, '*')))
            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])


if __name__ == "__main__":
    # Single-clip usage example:
    #     sdr, sisdr = calculate_sdr_and_sisdr(
    #         "path_to_original_audio.wav",
    #         ["path_to_segment_1.wav", "path_to_segment_2.wav"],
    #     )
    src_data_root = r'/data/sound/audioset/audios_32k'
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'

    _evaluate_split('balanced_train_segments', src_data_root, sep_data_root)
    # Other splits:
    # _evaluate_split('eval_segments', src_data_root, sep_data_root)
    # _evaluate_split('unbalanced_train_segments', src_data_root, sep_data_root, nested=True)