import csv
import glob
import os

import torch
import torchaudio
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from tqdm import tqdm


def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
    """Compute SDR and SI-SDR between an original recording and the sum of
    its separated segments.

    The separated waveforms are summed to reconstruct the mixture, then
    compared against the original with torchmetrics' SDR / SI-SDR metrics.

    Args:
        original_audio_path: Path of the original (reference) audio file.
        separated_audio_paths: Paths of the separated segment audio files;
            their waveforms are summed channel-wise to form the estimate.

    Returns:
        Tuple ``(sdr, sisdr)`` — both in dB, as Python floats.

    Raises:
        ValueError: If ``separated_audio_paths`` is empty (the original code
            crashed later with an opaque ``TypeError`` in this case).
    """
    if not separated_audio_paths:
        raise ValueError("separated_audio_paths must contain at least one path")

    # Load the reference audio.
    original_waveform, sample_rate = torchaudio.load(original_audio_path)

    # Sum the separated segments. Segment lengths may differ slightly, so
    # every operand is truncated to the shortest common length before adding
    # to keep tensor shapes compatible.
    combined_waveform = None
    for path in separated_audio_paths:
        separated_waveform, _ = torchaudio.load(path)
        min_length = min(original_waveform.size(1), separated_waveform.size(1))
        separated_waveform = separated_waveform[:, :min_length]
        if combined_waveform is None:
            combined_waveform = separated_waveform
        else:
            combined_waveform = combined_waveform[:, :min_length] + separated_waveform

    # Align the reference and the reconstruction to a common length.
    min_length = min(original_waveform.size(1), combined_waveform.size(1))
    original_waveform = original_waveform[:, :min_length]
    combined_waveform = combined_waveform[:, :min_length]

    # torchmetrics metrics take (preds, target): the summed estimate is the
    # prediction, the original audio is the target.
    sisdr_metric = ScaleInvariantSignalDistortionRatio()
    sisdr = sisdr_metric(combined_waveform, original_waveform).item()

    sdr_metric = SignalDistortionRatio()
    sdr = sdr_metric(combined_waveform, original_waveform).item()

    return sdr, sisdr


if __name__ == "__main__":
    dset = 'balanced_train_segments'
    # dset = 'eval_segments'
    src_data_root = r'/data/sound/audioset/audios_32k'
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'

    # newline='' is required by the csv module to avoid blank interleaved
    # rows on Windows; the `with` block guarantees the file is flushed and
    # closed (the original leaked the handle).
    with open(os.path.join(sep_data_root, dset + '.csv'), 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['video', 'sdr', 'sisdr'])
        for video_path in tqdm(glob.glob(os.path.join(sep_data_root, dset, '*'))):
            # basename is separator-agnostic, unlike splitting on '/'.
            video = os.path.basename(video_path)
            original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = glob.glob(os.path.join(video_path, '*'))
            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])

    # Variant for the sharded 'unbalanced_train_segments' split (one extra
    # directory level: <part>/<video>); kept for reference.
    # dset = 'unbalanced_train_segments'
    # src_data_root = r'/data/sound/audioset/audios_32k'
    # sep_data_root = r'data_engine_infer/audioset_separation_child_label'
    # writer = csv.writer(open(os.path.join(sep_data_root, dset + '.csv'), 'w'))
    # writer.writerow(['video', 'sdr', 'sisdr'])
    # for video_path in tqdm(glob.glob(os.path.join(sep_data_root, dset, '*', '*'))):
    #     part = video_path.split('/')[-2]
    #     video = video_path.split('/')[-1]
    #     original_audio_path = os.path.join(src_data_root, dset, part, video + '.wav')
    #     separated_audio_paths = glob.glob(video_path + '/*')
    #     sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
    #     writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])