# ChristophSchuhmann's picture
# Add model code, inference script, and examples
# dfd1909 verified
from typing import Optional, Literal, Union, Final, List
from numpy import ndarray
import os
from tqdm import tqdm
import numpy as np
import soundfile as sf
import librosa
from scipy.signal import resample_poly
try: import torch
except: print('import error: torch')
try: from torch import Tensor
except: print('')
try: import torchaudio
except: print('import error: torch')
try: from pydub import AudioSegment
except: print('import error: pydub')
from TorchJaekwon.Util.UtilData import UtilData
# (min, max) sample value associated with each supported dtype name.
# Float formats use the nominal [-1, 1] range; integer formats use the full
# two's-complement range of their bit width.
DATA_TYPE_MIN_MAX_DICT:Final[dict] = {
    'float32': (-1, 1),
    'float64': (-1, 1),
    'int16': (-(1 << 15), (1 << 15) - 1),
    'int32': (-(1 << 31), (1 << 31) - 1),
}
class UtilAudio:
    '''
    Static helpers for audio dtype conversion, resampling, file I/O,
    normalisation, and splitting/merging of long signals.

    Unless a method states otherwise, multi-channel audio is laid out as
    (channel, num_samples) and mono audio as (num_samples,).
    '''

    @staticmethod
    def change_dtype(audio:ndarray,
                     current_dtype:Literal['float32', 'float64', 'int16', 'int32'],
                     target_dtype:Literal['float32', 'float64', 'int16', 'int32']
                     ) -> ndarray:
        '''
        Rescale samples from the value range of current_dtype to that of target_dtype.

        Args:
            audio: samples expressed in the range of current_dtype.
            current_dtype: key of DATA_TYPE_MIN_MAX_DICT describing the input range.
            target_dtype: key of DATA_TYPE_MIN_MAX_DICT describing the output range;
                also the numpy dtype of the returned array.

        Returns:
            A new array of dtype target_dtype. Float values are truncated (not
            rounded) when the target is an integer type.
        '''
        current_min, current_max = DATA_TYPE_MIN_MAX_DICT[current_dtype]
        target_min, target_max = DATA_TYPE_MIN_MAX_DICT[target_dtype]
        # Work in float64: float32 cannot represent 2**31 - 1, so scaling a
        # full-scale float32 sample to int32 would overflow on the final cast.
        scaled = np.clip(np.asarray(audio, dtype=np.float64), a_min=current_min, a_max=current_max)
        scaled = scaled / current_max * target_max
        # The most negative integer sample maps slightly beyond the target
        # range (|min| > max), so clip again before casting.
        scaled = np.clip(scaled, a_min=target_min, a_max=target_max)
        return scaled.astype(getattr(np, target_dtype))

    @staticmethod
    def resample_audio(audio:Union[ndarray, Tensor], #[shape=(channel, num_samples) or (num_samples)]
                       origin_sr:int,
                       target_sr:int,
                       resample_module:Literal['librosa', 'resample_poly', 'torchaudio'] = 'librosa',
                       resample_type:str = "kaiser_fast",
                       audio_path:Optional[str] = None):
        '''
        Resample audio from origin_sr to target_sr along the time (last) axis.

        Args:
            audio: signal to resample.
            origin_sr: sample rate of the input.
            target_sr: desired sample rate.
            resample_module: back-end used for the conversion.
            resample_type: forwarded to librosa as res_type; unused by the other back-ends.
            audio_path: accepted for interface compatibility; not used.

        Returns:
            The input object itself when both rates are equal, otherwise the
            resampled signal.

        Raises:
            ValueError: if resample_module is not one of the supported names.
        '''
        if origin_sr == target_sr:
            return audio
        if resample_module == 'librosa':
            return librosa.resample(audio, orig_sr=origin_sr, target_sr=target_sr, res_type=resample_type)
        if resample_module == 'resample_poly':
            # resample_poly defaults to axis=0, which is the channel axis for
            # (channel, num_samples) input; resample along time instead.
            return resample_poly(x=audio, up=target_sr, down=origin_sr, axis=-1)
        if resample_module == 'torchaudio':
            # A fresh transform is built on every call, so its kernel is not
            # reused between calls.
            resampler = torchaudio.transforms.Resample(orig_freq=origin_sr, new_freq=target_sr)
            if isinstance(audio, torch.Tensor):
                # Keep the resampling kernel on the same device as the signal.
                resampler = resampler.to(audio.device)
            return resampler(audio)
        raise ValueError(f'[Error] unknown resample_module: {resample_module}')

    @staticmethod
    def read(audio_path:str,
             sample_rate:Optional[int] = None,
             mono:Optional[bool] = None,
             start_idx:int = 0,
             end_idx:Optional[int] = None,
             module_name:Literal['soundfile','librosa', 'torchaudio'] = 'torchaudio',
             return_type:Union[ndarray, Tensor] = ndarray
             ) -> tuple: #(audio [shape=(channel, num_samples) or (num_samples)], sample_rate)
        '''
        Load an audio file and optionally resample it / change its channel count.

        Args:
            audio_path: file to load.
            sample_rate: target sample rate; None keeps the file's own rate.
            mono: True averages a 2-channel signal to one channel, False
                duplicates a single channel into two, None leaves channels as loaded.
            start_idx: first frame to load (only applied by the torchaudio back-end).
            end_idx: frame to stop at (only applied by the torchaudio back-end).
            module_name: back-end used to decode the file.
            return_type: accepted for interface compatibility; not used. The
                soundfile and librosa back-ends yield ndarrays, torchaudio yields a Tensor.

        Returns:
            (audio, sample_rate) where sample_rate is the requested rate when
            one was given, otherwise the rate reported by the back-end.

        Raises:
            ValueError: if module_name is not one of the supported names.
            AssertionError: if end_idx <= start_idx (torchaudio) or the loaded
                audio is neither 1-D nor 2-D with one or two channels.
        '''
        if module_name == "soundfile":
            audio_data, original_samplerate = sf.read(audio_path)
            # soundfile returns (num_samples, channel); switch to (channel, num_samples).
            if len(audio_data.shape) > 1:
                audio_data = audio_data.T
            if sample_rate is not None and sample_rate != original_samplerate:
                audio_data = UtilAudio.resample_audio(audio_data, original_samplerate, sample_rate)
        elif module_name == "librosa":
            print(f"read audio sr: {sample_rate}")
            audio_data, original_samplerate = librosa.load(audio_path, sr=sample_rate, mono=mono)
        elif module_name == 'torchaudio':
            if end_idx is not None:
                assert end_idx > start_idx, f'[Error] end_idx must be larger than start_idx'
            #[channel, time], int
            audio_data, original_samplerate = torchaudio.load(
                audio_path,
                frame_offset=start_idx,
                num_frames=-1 if end_idx is None else end_idx - start_idx)
            if sample_rate is not None and sample_rate != original_samplerate:
                audio_data = UtilAudio.resample_audio(audio=audio_data,
                                                      origin_sr=original_samplerate,
                                                      target_sr=sample_rate,
                                                      resample_module='torchaudio',
                                                      audio_path=audio_path)
        else:
            raise ValueError(f'[Error] unknown module_name: {module_name}')

        if mono is not None:
            is_tensor = isinstance(audio_data, torch.Tensor)
            if mono and len(audio_data.shape) == 2 and audio_data.shape[0] == 2:
                audio_data = torch.mean(audio_data, dim=0) if is_tensor else np.mean(audio_data, axis=0)
            elif not mono and (len(audio_data.shape) == 1 or audio_data.shape[0] == 1):
                # Duplicate the single channel using the same array type that
                # was loaded (the numpy back-ends cannot be written into a torch buffer).
                single_channel = audio_data.reshape(-1)
                if is_tensor:
                    audio_data = torch.stack((single_channel, single_channel))
                else:
                    audio_data = np.stack((single_channel, single_channel))

        assert ((len(audio_data.shape)==1) or ((len(audio_data.shape)==2) and audio_data.shape[0] in [1,2])),f'[read audio shape problem] path: {audio_path} shape: {audio_data.shape}'
        return audio_data, original_samplerate if sample_rate is None else sample_rate

    @staticmethod
    def write(audio_path:str,
              audio:Union[ndarray, Tensor],
              sample_rate:int,
              ) -> None:
        '''
        Save audio to audio_path with soundfile, creating parent directories as needed.

        Args:
            audio_path: destination file.
            audio: 1-D or 2-D signal; a Tensor is squeezed and moved to CPU first.
            sample_rate: sample rate stored in the file.

        Raises:
            AssertionError: if the (squeezed) audio has more than two dimensions.
        '''
        parent_dir = os.path.dirname(audio_path)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if isinstance(audio, Tensor):
            audio = audio.squeeze().cpu().detach().numpy()
        assert len(audio.shape) <= 2, f'[Error] shape of {audio_path}: {audio.shape}'
        # soundfile expects (num_samples, channel); the shorter axis is taken to be the channel axis.
        if len(audio.shape) == 2 and audio.shape[0] < audio.shape[1]:
            audio = audio.T
        sf.write(file=audio_path, data=audio, samplerate=sample_rate)

    @staticmethod
    def stereo_to_mono(audio_data:Union[ndarray, Tensor]) -> Union[ndarray, Tensor]:
        '''
        Average a 2-D signal over axis 1.

        NOTE(review): axis 1 is the channel axis only for (num_samples, channel)
        input, whereas the rest of this class uses (channel, num_samples) -
        confirm the layout callers pass in.
        '''
        if isinstance(audio_data, Tensor):
            return torch.mean(audio_data, dim=1)
        return np.mean(audio_data, axis=1)

    @staticmethod
    def mono_to_stereo(audio_data:Union[ndarray, Tensor]) -> Union[ndarray, Tensor]:
        '''
        Duplicate a 1-D signal into a (2, num_samples) float64 numpy array.
        '''
        stereo_audio = np.zeros((2, len(audio_data)))
        for channel_idx in range(2):
            stereo_audio[channel_idx, ...] = audio_data
        return stereo_audio

    @staticmethod
    def normalize_volume(audio_input:ndarray, sr:int, target_dBFS = -30):
        '''
        Apply a gain so that the signal's loudness becomes target_dBFS.

        The signal is converted to int32, wrapped in a single-channel pydub
        AudioSegment, gained by (target_dBFS - current dBFS) and converted back.

        Args:
            audio_input: mono signal in the float [-1, 1] range.
            sr: sample rate of the signal.
            target_dBFS: desired loudness in dBFS.

        Returns:
            float64 array with the gain applied. A silent input, whose dBFS is
            -inf and for which no finite gain exists, is returned without gain.
        '''
        audio = UtilAudio.change_dtype(audio=audio_input, current_dtype='float64', target_dtype='int32')
        audio_segment = AudioSegment(audio.tobytes(), frame_rate=sr, sample_width=audio.dtype.itemsize, channels=1)
        current_dBFS = audio_segment.dBFS
        if np.isinf(current_dBFS):
            # Silence: the required gain would be infinite, so skip the gain step.
            return UtilAudio.change_dtype(audio=audio, current_dtype='int32', target_dtype='float64')
        normalized_sound = audio_segment.apply_gain(target_dBFS - current_dBFS)
        return UtilAudio.change_dtype(audio=np.array(normalized_sound.get_array_of_samples()),
                                      current_dtype='int32',
                                      target_dtype='float64')

    @staticmethod
    def normalize_by_fro_norm(audio_input:Tensor #[batch, channel, time]
                              ) -> Tensor:
        '''
        Scale every batch item to unit Frobenius norm.

        Each item is flattened, divided by its norm and reshaped back. The norm
        is clamped to the smallest positive normal value of its dtype so that an
        all-zero item yields zeros instead of NaN.
        '''
        original_shape:tuple = audio_input.shape
        flat_audio = audio_input.reshape(original_shape[0], -1)
        norm = torch.norm(flat_audio, p="fro", dim=1, keepdim=True)
        norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
        return (flat_audio / norm).reshape(*original_shape)

    @staticmethod
    def energy_unify(estimated, original, eps = 1e-12):
        '''
        Project `original` onto `estimated` to obtain an energy-matched target.

        Returns:
            (estimated, target) with target = <estimated, original> * original / (||original||^2 + eps).
        '''
        target = UtilAudio.pow_norm(estimated, original) * original
        target /= UtilAudio.pow_p_norm(original) + eps
        return estimated, target

    @staticmethod
    def pow_norm(s1, s2):
        '''Return the sum of the element-wise product of s1 and s2.'''
        return torch.sum(s1 * s2)

    @staticmethod
    def pow_p_norm(signal):
        '''Return the squared L2 norm of signal.'''
        return torch.pow(torch.norm(signal, p=2), 2)

    @staticmethod
    def get_segment_index_list(audio:ndarray, #[time]
                               sample_rate:int,
                               segment_sample_length:int,
                               hop_seconds:float = 0.1
                               ) -> list:
        '''
        List the {'begin', 'end'} sample indices of sliding segments over audio.

        The first segment always starts at 0, even if it extends past the end of
        the audio; further segments are added every hop while
        begin + segment_sample_length < len(audio).

        Raises:
            ValueError: if hop_seconds * sample_rate truncates to less than one
                sample, which would otherwise never advance the window.
        '''
        hop_samples = int(hop_seconds * sample_rate)
        if hop_samples <= 0:
            raise ValueError(f'[Error] hop must be at least one sample, but got {hop_samples}')
        last_begin_exclusive = len(audio) - segment_sample_length
        begin_list = [0] + list(range(hop_samples, last_begin_exclusive, hop_samples))
        return [{'begin': begin, 'end': begin + segment_sample_length} for begin in begin_list]

    @staticmethod
    def audio_to_batch(audio:Tensor, #[Length]
                       segment_length:int,
                       overlap_length:int = 48000 #recommend: int(sr * 0.5)
                       ):
        '''
        Cut a 1-D signal into overlapping segments and stack them into a batch.

        Consecutive segments start segment_length - overlap_length samples
        apart. Each slice is passed through UtilData.fix_length (presumably
        padding the last, shorter slice to segment_length - confirm in UtilData).

        Raises:
            AssertionError: if audio is not 1-D.
            ValueError: if overlap_length >= segment_length, which would
                otherwise never advance the window.
        '''
        assert len(audio.shape) == 1, f'[Error] audio shape must be 1, but {audio.shape}'
        hop_length = segment_length - overlap_length
        if hop_length <= 0:
            raise ValueError(f'[Error] segment_length ({segment_length}) must be larger than overlap_length ({overlap_length})')
        audio_list = [UtilData.fix_length(audio[start_idx:start_idx + segment_length], segment_length)
                      for start_idx in range(0, len(audio), hop_length)]
        return torch.stack(audio_list)

    @staticmethod
    def merge_batch_w_cross_fade(batch_audio:Union[List[ndarray],ndarray,Tensor],
                                 segment_length:int,
                                 overlap_length:int = 48000 #recommend: int(sr * 0.5)
                                 ) -> Union[ndarray, Tensor]:
        '''
        Overlap-add segments into one signal with a linear cross-fade.

        reference from https://github.com/nkandpa2/music_enhancement/blob/master/scripts/generate_from_wav.py

        Args:
            batch_audio: segments of length segment_length; a single 1-D ndarray
                is treated as one segment. The input is not modified.
            segment_length: number of samples per segment.
            overlap_length: number of samples shared by neighbouring segments.

        Returns:
            A Tensor (on the batch's device) for Tensor input, otherwise an ndarray.
        '''
        if isinstance(batch_audio, ndarray) and len(batch_audio.shape) == 1:
            batch_audio = [batch_audio]
        is_tensor = isinstance(batch_audio, torch.Tensor)
        num_segments = len(batch_audio)
        hop_length:int = segment_length - overlap_length
        output_audio_length:int = num_segments * segment_length - (num_segments - 1) * overlap_length

        cross_fade_in = np.linspace(0, 1, overlap_length)
        cross_fade_out = 1 - cross_fade_in
        if is_tensor:
            # Allocate on the batch's device so the accumulation below does not mix devices.
            output_audio = torch.zeros(output_audio_length, device=batch_audio.device)
            cross_fade_in = torch.tensor(cross_fade_in, device=batch_audio.device)
            cross_fade_out = torch.tensor(cross_fade_out, device=batch_audio.device)
        else:
            output_audio = np.zeros(output_audio_length)

        for i in range(num_segments):
            # Fade a copy so the caller's segments stay untouched.
            segment = batch_audio[i].clone() if is_tensor else np.array(batch_audio[i])
            if i != 0:
                segment[:overlap_length] *= cross_fade_in
            if i != num_segments - 1:
                # Explicit start index: a "-overlap_length" slice would select
                # the whole segment when overlap_length is 0.
                segment[len(segment) - overlap_length:] *= cross_fade_out
            start_idx = i * hop_length
            output_audio[start_idx:start_idx + segment_length] += segment
        return output_audio

    @staticmethod
    def analyze_audio_dataset(data_dir:str,
                              result_save_dir:str,
                              sanity_check_sr:Optional[Union[int,List[int]]] = None,
                              save_each_meta:bool = False
                              ) -> None:
        '''
        Collect duration statistics for every .wav/.mp3/.flac file under data_dir.

        Writes {result_save_dir}/meta.yaml containing the total duration, the
        metadata of the longest file and the list of files that failed to load.

        Args:
            data_dir: root directory to scan.
            result_save_dir: directory receiving meta.yaml (and per-file pickles).
            sanity_check_sr: expected sample rate, or list of accepted rates;
                a mismatching file triggers an AssertionError.
            save_each_meta: also pickle the metadata of each file, mirroring the
                layout of data_dir under result_save_dir.
        '''
        total_meta_dict:dict = {
            'total_duration_second': 0,
            'total_duration_minutes': 0,
            'total_duration_hours': 0,
            'longest_sample_meta': {
                'file_name': '',
                'duration_second':0
            },
            'error_file_list': list()
        }
        if sanity_check_sr is not None:
            total_meta_dict['sample_rate'] = sanity_check_sr

        audio_meta_data_list = UtilData.walk(dir_name=data_dir, ext=['.wav', '.mp3', '.flac'])
        for meta_data in tqdm(audio_meta_data_list):
            try:
                audio, sr = UtilAudio.read(meta_data['file_path'], mono=True)
            except Exception:
                # Unreadable files are recorded and skipped; Exception (not a
                # bare except) lets Ctrl-C still abort the scan.
                print(f'Error: {meta_data["file_path"]}')
                total_meta_dict['error_file_list'].append(meta_data['file_path'])
                continue

            if sanity_check_sr is not None:
                if isinstance(sanity_check_sr, int): assert sr == sanity_check_sr, f'''{meta_data['file_path']}'s sample rate is {sr}'''
                if isinstance(sanity_check_sr, list): assert sr in sanity_check_sr, f'''{meta_data['file_path']}'s sample rate is {sr}'''

            sample_length = audio.shape[-1]
            meta_data_of_this_file = {
                'file_name': meta_data['file_name'],
                'file_path': os.path.abspath(meta_data['file_path']),
                'sample_length': sample_length,
                'sample_rate': sr,
                'duration_second': sample_length / sr,
            }

            if save_each_meta:
                save_dir:str = meta_data['dir_path'].replace(data_dir, result_save_dir)
                UtilData.pickle_save(f'''{save_dir}/{meta_data['file_name']}.pkl''', meta_data_of_this_file)

            total_meta_dict['total_duration_second'] += meta_data_of_this_file['duration_second']
            if total_meta_dict['longest_sample_meta']['duration_second'] < meta_data_of_this_file['duration_second']:
                total_meta_dict['longest_sample_meta'] = meta_data_of_this_file

        total_meta_dict['total_duration_minutes'] = total_meta_dict['total_duration_second'] / 60
        total_meta_dict['total_duration_hours'] = total_meta_dict['total_duration_second'] / 3600

        UtilData.yaml_save(save_path = f'{result_save_dir}/meta.yaml', data = total_meta_dict)