# Author: Tianhao Wang ("first commit", dbbd709, 3.93 kB)
import os
import csv
import glob
from tqdm import tqdm
import torch
import torchaudio
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
    """Compute SDR and SI-SDR between the original audio and the sum of
    its separated segments.

    Parameters:
    - original_audio_path: str, path to the original audio file.
    - separated_audio_paths: List[str], paths to the separated audio segments.

    Returns:
    - sdr: float, signal-to-distortion ratio in dB.
    - sisdr: float, scale-invariant SDR in dB.

    Raises:
    - ValueError: if `separated_audio_paths` is empty (previously this
      crashed later with an opaque AttributeError on None).
    """
    if not separated_audio_paths:
        raise ValueError("separated_audio_paths must contain at least one path")

    # Load the reference audio.
    original_waveform, sample_rate = torchaudio.load(original_audio_path)

    # Sum the separated segments. Segments may differ slightly in length, so
    # truncate to the shortest common length before each addition.
    combined_waveform = None
    for path in separated_audio_paths:
        separated_waveform, _ = torchaudio.load(path)
        if combined_waveform is None:
            combined_waveform = separated_waveform
        else:
            # BUGFIX: align on the combined-vs-segment length. The original
            # code used min(original, segment), which raises a shape-mismatch
            # error whenever an earlier truncation left `combined_waveform`
            # shorter than the current segment.
            min_length = min(combined_waveform.size(1), separated_waveform.size(1))
            combined_waveform = (
                combined_waveform[:, :min_length] + separated_waveform[:, :min_length]
            )

    # Trim both signals to a common length before computing the metrics.
    min_length = min(original_waveform.size(1), combined_waveform.size(1))
    original_waveform = original_waveform[:, :min_length]
    combined_waveform = combined_waveform[:, :min_length]

    # torchmetrics convention is metric(preds, target).
    sisdr_metric = ScaleInvariantSignalDistortionRatio()
    sisdr = sisdr_metric(combined_waveform, original_waveform).item()
    sdr_metric = SignalDistortionRatio()
    sdr = sdr_metric(combined_waveform, original_waveform).item()
    # print(f"SI-SDR between original and combined audio: {sisdr} dB")
    # print(f"SDR between original and combined audio: {sdr} dB")
    return sdr, sisdr
if __name__ == "__main__":
    # Example: specify the original audio and the separated segment paths.
    # original_audio_path = "path_to_original_audio.wav"
    # separated_audio_paths = [
    #     "path_to_segment_1.wav",
    #     "path_to_segment_2.wav",
    #     "path_to_segment_3.wav",
    # ]
    # sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)

    # Which AudioSet split to evaluate.
    dset = 'balanced_train_segments'
    # dset = 'eval_segments'
    src_data_root = r'/data/sound/audioset/audios_32k'
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'

    # FIX: open the CSV in a context manager so it is always flushed and
    # closed (the original leaked the file handle), and pass newline='' as
    # the csv module requires to avoid blank rows on Windows.
    with open(os.path.join(sep_data_root, dset + '.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['video', 'sdr', 'sisdr'])
        for video_path in tqdm(glob.glob(os.path.join(sep_data_root, dset, '*'))):
            # basename is equivalent to split('/')[-1] but portable.
            video = os.path.basename(video_path)
            original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = glob.glob(os.path.join(video_path, '*'))
            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])

    # Variant for the sharded unbalanced split (one extra directory level):
    # dset = 'unbalanced_train_segments'
    # src_data_root = r'/data/sound/audioset/audios_32k'
    # sep_data_root = r'data_engine_infer/audioset_separation_child_label'
    # with open(os.path.join(sep_data_root, dset + '.csv'), 'w', newline='') as f:
    #     writer = csv.writer(f)
    #     writer.writerow(['video', 'sdr', 'sisdr'])
    #     for video_path in tqdm(glob.glob(os.path.join(sep_data_root, dset, '*', '*'))):
    #         part = video_path.split('/')[-2]
    #         video = os.path.basename(video_path)
    #         original_audio_path = os.path.join(src_data_root, dset, part, video + '.wav')
    #         separated_audio_paths = glob.glob(os.path.join(video_path, '*'))
    #         sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
    #         writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])