File size: 13,374 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from typing import Optional, Literal, Union, Final, List
from numpy import ndarray

import os
from tqdm import tqdm
import numpy as np
import soundfile as sf
import librosa
from scipy.signal import resample_poly

# Optional heavy dependencies. A failed import is reported on stdout instead of
# raising here. `except Exception` (rather than a bare `except`) lets
# KeyboardInterrupt / SystemExit propagate while still covering the non-ImportError
# failures that binary packages can raise on import.
try:
    import torch
except Exception:
    print('import error: torch')
try:
    from torch import Tensor
except Exception:
    print('import error: torch.Tensor')
try:
    import torchaudio
except Exception:
    print('import error: torchaudio')
try:
    from pydub import AudioSegment
except Exception:
    print('import error: pydub')

from TorchJaekwon.Util.UtilData import UtilData

# (min, max) sample value for each supported dtype name; UtilAudio.change_dtype uses
# these bounds to clip the input and to rescale it by the max of each dtype.
DATA_TYPE_MIN_MAX_DICT:Final[dict] = {'float32':(-1,1), 'float64':(-1,1), 'int16':(-2**15, 2**15-1), 'int32':(-2**31,2**31-1)}

class UtilAudio:
    @staticmethod
    def change_dtype(audio:ndarray,
                     current_dtype:Literal['float32', 'float64', 'int16', 'int32'],
                     target_dtype:Literal['float32', 'float64', 'int16', 'int32']
                     ) -> ndarray:
        '''
        Rescale audio samples from the value range of one dtype to another and cast.

        The input is clipped to the range of current_dtype, divided by the maximum of
        current_dtype, multiplied by the maximum of target_dtype and cast to
        target_dtype (the cast truncates toward zero for integer targets).

        Args:
            audio: samples expressed in the value range of current_dtype.
            current_dtype: key of DATA_TYPE_MIN_MAX_DICT describing the input range.
            target_dtype: key of DATA_TYPE_MIN_MAX_DICT describing the output range/dtype.

        Returns:
            A new array of dtype target_dtype.

        Raises:
            KeyError: if either dtype name is not in DATA_TYPE_MIN_MAX_DICT.
        '''
        current_min, current_max = DATA_TYPE_MIN_MAX_DICT[current_dtype]
        target_min, target_max = DATA_TYPE_MIN_MAX_DICT[target_dtype]

        # Do the arithmetic in float64: float32 cannot represent 2**31-1 exactly, so a
        # full-scale float32 sample would round up to 2**31 and overflow the int32 cast.
        samples = np.asarray(audio, dtype=np.float64)
        samples = np.clip(samples, a_min = current_min, a_max = current_max)
        samples = samples / current_max * target_max

        # The ranges are asymmetric for integers (|min| = max + 1), so the most negative
        # input sample can land just outside the target range; clip before casting.
        samples = np.clip(samples, a_min = target_min, a_max = target_max)
        return samples.astype(getattr(np, target_dtype))
    
    @staticmethod
    def resample_audio(audio:Union[ndarray, Tensor], #[shape=(channel, num_samples) or (num_samples)]
                       origin_sr:int,
                       target_sr:int,
                       resample_module:Literal['librosa', 'resample_poly', 'torchaudio'] = 'librosa',
                       resample_type:str = "kaiser_fast",
                       audio_path:Optional[str] = None):
        if(origin_sr == target_sr): return audio
        #print(f"resample audio {origin_sr} to {target_sr}")
        if resample_module == 'librosa':
            return librosa.resample(audio, orig_sr=origin_sr, target_sr=target_sr, res_type=resample_type)
        elif resample_module == 'resample_poly':
            return resample_poly(x = audio, up = target_sr, down = origin_sr)
        elif resample_module == 'torchaudio':
            #transforms.Resample precomputes and caches the kernel used for resampling, while functional.resample computes it on the fly
            #so using torchaudio.transforms.Resample will result in a speedup when resampling multiple waveforms using the same parameters
            return torchaudio.transforms.Resample(orig_freq = origin_sr, new_freq = target_sr)(audio)
    
    @staticmethod
    def read(audio_path:str,
             sample_rate:Optional[int] = None,
             mono:Optional[bool] = None,
             start_idx:int = 0,
             end_idx:Optional[int] = None,
             module_name:Literal['soundfile','librosa', 'torchaudio'] = 'torchaudio',
             return_type:Union[ndarray, Tensor] = ndarray
            ) -> tuple: #(audio [shape=(channel, num_samples) or (num_samples)], sample_rate)
        '''
        Load an audio file and optionally resample it and change its channel layout.

        Args:
            audio_path: file to read.
            sample_rate: if given, the audio is resampled to this rate.
            mono: True averages a 2-channel signal to one channel, False duplicates a
                single channel into two, None leaves the layout untouched.
            start_idx: first frame to read (honoured by 'soundfile' and 'torchaudio').
            end_idx: frame to stop before; None reads to the end (honoured by
                'soundfile' and 'torchaudio').
            module_name: backend used for loading. 'torchaudio' yields a Tensor, the
                other two yield an ndarray. The 'librosa' backend does not use
                start_idx / end_idx.
            return_type: currently not used by this function.

        Returns:
            (audio, sample_rate): sample_rate is the requested rate when one was
            given, otherwise the rate reported by the backend.

        Raises:
            ValueError: if module_name is not a supported backend.
            AssertionError: if end_idx <= start_idx, or the final shape is neither
                1-D nor (1 or 2, num_samples).
        '''
        if module_name in ('soundfile', 'torchaudio') and end_idx is not None:
            assert end_idx > start_idx, f'[Error] end_idx must be larger than start_idx'

        if module_name == "soundfile":
            # start/stop default to the whole file, matching the previous behaviour
            # when no indices are passed.
            audio_data, original_samplerate = sf.read(audio_path, start = start_idx, stop = end_idx)
            # soundfile returns (num_samples, channel); this file uses (channel, num_samples).
            if len(audio_data.shape) > 1:
                audio_data = audio_data.T
            if sample_rate is not None and sample_rate != original_samplerate:
                audio_data = UtilAudio.resample_audio(audio_data,original_samplerate,sample_rate)

        elif module_name == "librosa":
            print(f"read audio sr: {sample_rate}")
            audio_data, original_samplerate = librosa.load( audio_path, sr=sample_rate, mono=mono)

        elif module_name == 'torchaudio':
            #[channel, time], int
            audio_data, original_samplerate = torchaudio.load(audio_path,
                                                              frame_offset = start_idx,
                                                              num_frames = -1 if end_idx is None else end_idx - start_idx)
            if sample_rate is not None and sample_rate != original_samplerate:
                audio_data = UtilAudio.resample_audio(audio = audio_data, origin_sr=original_samplerate, target_sr = sample_rate, resample_module='torchaudio', audio_path = audio_path)

        else:
            # Fail early instead of hitting an unbound `audio_data` further down.
            raise ValueError(f'[Error] unknown module_name: {module_name}')

        if mono is not None:
            is_tensor:bool = isinstance(audio_data, torch.Tensor)
            if mono and len(audio_data.shape) == 2 and audio_data.shape[0] == 2:
                audio_data = torch.mean(audio_data,axis=0) if is_tensor else np.mean(audio_data,axis=0)
            elif not mono and (len(audio_data.shape) == 1 or audio_data.shape[0] == 1):
                # Duplicate the single channel with the container type that was loaded;
                # writing an ndarray into a torch.zeros buffer does not work.
                single_channel = audio_data.reshape(-1)
                if is_tensor:
                    audio_data = torch.stack([single_channel, single_channel])
                else:
                    audio_data = np.stack([single_channel, single_channel])

        assert ((len(audio_data.shape)==1) or ((len(audio_data.shape)==2) and audio_data.shape[0] in [1,2])),f'[read audio shape problem] path: {audio_path} shape: {audio_data.shape}'

        returned_sample_rate = original_samplerate if sample_rate is None else sample_rate
        return audio_data, returned_sample_rate
    
    @staticmethod
    def write(audio_path:str,
              audio:Union[ndarray, Tensor],
              sample_rate:int,
              ) -> None:
        '''
        Write audio to audio_path with soundfile, creating the parent directory if needed.

        Args:
            audio_path: destination file path.
            audio: 1-D or 2-D samples. A Tensor is squeezed, detached and moved to CPU
                before being converted to an ndarray.
            sample_rate: sample rate passed to soundfile.

        Raises:
            AssertionError: if the audio has more than two dimensions.
        '''
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        parent_dir:str = os.path.dirname(audio_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if isinstance(audio, Tensor):
            audio = audio.squeeze().cpu().detach().numpy()
        assert len(audio.shape) <= 2, f'[Error] shape of {audio_path}: {audio.shape}'

        # When the first axis is the shorter one it is treated as the channel axis and
        # the array is transposed, so the longer axis comes first for sf.write.
        if len(audio.shape) == 2 and audio.shape[0] < audio.shape[1]:
            audio = audio.T
        sf.write(file = audio_path, data = audio, samplerate = sample_rate)
    
    @staticmethod
    def stereo_to_mono(audio_data:Union[ndarray, Tensor]) -> Union[ndarray, Tensor]:
        audio_data = np.mean(audio_data,axis=1)
        return audio_data
    
    @staticmethod
    def mono_to_stereo(audio_data:Union[ndarray, Tensor]) -> Union[ndarray, Tensor]:
        stereo_audio = np.zeros((2,len(audio_data)))
        stereo_audio[0,...] = audio_data
        stereo_audio[1,...] = audio_data
        audio_data = stereo_audio
        return audio_data

    @staticmethod
    def normalize_volume(audio_input:ndarray,sr:int, target_dBFS = -30):
        '''
        Apply a gain so that the loudness reported by pydub equals target_dBFS.

        The samples are converted to int32, wrapped in a single-channel AudioSegment,
        gained, and converted back to float64.

        Args:
            audio_input: samples in the float64 range handled by change_dtype; the
                segment is created with channels=1, so mono input is assumed.
            sr: sample rate of the audio.
            target_dBFS: desired level in dBFS.

        Returns:
            float64 ndarray with the gained samples. Silent (or empty) input is
            returned without gain.
        '''
        audio = UtilAudio.change_dtype(audio=audio_input,current_dtype='float64',target_dtype='int32')
        audio_segment = AudioSegment(audio.tobytes(), frame_rate=sr, sample_width=audio.dtype.itemsize, channels=1)

        current_dBFS = audio_segment.dBFS
        # A silent segment reports -inf dBFS; the gain would then be +inf, which cannot
        # be applied meaningfully. Skip the gain in that case.
        if np.isfinite(current_dBFS):
            audio_segment = audio_segment.apply_gain(target_dBFS - current_dBFS)

        normalized_samples = np.array(audio_segment.get_array_of_samples())
        return UtilAudio.change_dtype(audio=normalized_samples,current_dtype='int32',target_dtype='float64')
    
    @staticmethod
    def normalize_by_fro_norm(audio_input:Tensor #[batch, channel, time]
                              ) -> Tensor:
        original_shape:tuple = audio_input.shape
        audio = audio_input.reshape(original_shape[0], -1)
        audio = audio/torch.norm(audio, p="fro", dim=1, keepdim=True)
        audio = audio.reshape(*original_shape)
        return audio

    @staticmethod
    def energy_unify(estimated, original, eps = 1e-12):
        target = UtilAudio.pow_norm(estimated, original) * original
        target /= UtilAudio.pow_p_norm(original) + eps
        return estimated, target

    @staticmethod
    def pow_norm(s1, s2):
        return torch.sum(s1 * s2)

    @staticmethod
    def pow_p_norm(signal):
        return torch.pow(torch.norm(signal, p=2), 2)
    
    @staticmethod
    def get_segment_index_list(audio:ndarray, #[time]
                               sample_rate:int,
                               segment_sample_length:int,
                               hop_seconds:float = 0.1
                               ) -> list:
        begin_sample:int = 0
        hop_samples = int(hop_seconds * sample_rate)
        segment_index_list = list()
        while (begin_sample == 0) or (begin_sample + segment_sample_length < len(audio)):
            segment_index_list.append({'begin':begin_sample, 'end':begin_sample + segment_sample_length})
            begin_sample += hop_samples
        return segment_index_list
    
    @staticmethod
    def audio_to_batch(audio:Tensor, #[Length]
                       segment_length:int,
                       overlap_length:int = 48000 #recommend: int(sr * 0.5)
                       ):
        assert len(audio.shape) == 1, f'[Error] audio shape must be 1, but {audio.shape}'
        start_idx:int = 0
        audio_list = list()
        while start_idx < len(audio):
            audio_segment = audio[start_idx:start_idx+segment_length]
            audio_segment = UtilData.fix_length(audio_segment, segment_length)
            audio_list.append(audio_segment)
            start_idx += segment_length - overlap_length
        return torch.stack(audio_list)
    
    @staticmethod
    def merge_batch_w_cross_fade(batch_audio:Union[List[ndarray],ndarray,Tensor],
                                 segment_length:int,
                                 overlap_length:int = 48000 #recommend: int(sr * 0.5)
                                 ) -> ndarray:
        '''
        reference from https://github.com/nkandpa2/music_enhancement/blob/master/scripts/generate_from_wav.py
        '''
        if isinstance(batch_audio, ndarray) and len(batch_audio.shape) == 1:
            batch_audio = [batch_audio]
        output_audio_length:int = len(batch_audio) * segment_length - (len(batch_audio) - 1) * overlap_length
        output_audio:Union[ndarray,Tensor] = torch.zeros(output_audio_length) if isinstance(batch_audio, torch.Tensor) else np.zeros(output_audio_length)
        hop_length:int = segment_length - overlap_length
        
        cross_fade_in:ndarray = np.linspace(0, 1, overlap_length)
        cross_fade_out:ndarray = 1 - cross_fade_in
        if isinstance(batch_audio, torch.Tensor):
            cross_fade_in = torch.tensor(cross_fade_in, device = batch_audio.device)
            cross_fade_out = torch.tensor(cross_fade_out, device = batch_audio.device)
        
        for i in range(0,len(batch_audio)):
            start_idx:int = i * hop_length
            if i != 0:
                batch_audio[i][:overlap_length] *= cross_fade_in
            if i != len(batch_audio) - 1:
                batch_audio[i][-overlap_length:] *= cross_fade_out
            output_audio[start_idx:start_idx+segment_length] += batch_audio[i]
        return output_audio
    
    @staticmethod
    def analyze_audio_dataset(data_dir:str, 
                              result_save_dir:str,
                              sanity_check_sr:Optional[Union[int,List[int]]] = None,
                              save_each_meta:bool = False
                              ) -> None:
        '''
        Scan a directory for audio files and save duration statistics as YAML.

        Every '.wav', '.mp3' and '.flac' file reported by UtilData.walk is read as
        mono. Files that fail to load are listed under 'error_file_list' and skipped.
        The summary is written to '{result_save_dir}/meta.yaml'.

        Args:
            data_dir: root directory to scan.
            result_save_dir: directory that receives meta.yaml (and the per-file
                pickles when save_each_meta is True).
            sanity_check_sr: expected sample rate, or list of accepted sample rates;
                None disables the check.
            save_each_meta: also pickle the metadata of every file, mirroring the
                layout of data_dir under result_save_dir.

        Raises:
            AssertionError: if a file's sample rate does not match sanity_check_sr.
        '''
        total_meta_dict:dict = {
            'total_duration_second': 0,
            'total_duration_minutes': 0,
            'total_duration_hours': 0,

            'longest_sample_meta': {
                'file_name': '',
                'duration_second':0
            },

            'error_file_list': list()
        }
        if sanity_check_sr is not None: total_meta_dict['sample_rate'] = sanity_check_sr

        audio_meta_data_list = UtilData.walk(dir_name=data_dir, ext=['.wav', '.mp3', '.flac'])
        for meta_data in tqdm(audio_meta_data_list):
            try:
                audio, sr = UtilAudio.read(meta_data['file_path'], mono=True)
            except Exception:
                # `Exception` rather than a bare except: an unreadable file is recorded
                # and skipped, but Ctrl-C still stops a long scan.
                print(f'Error: {meta_data["file_path"]}')
                total_meta_dict['error_file_list'].append(meta_data['file_path'])
                continue

            if sanity_check_sr is not None:
                if isinstance(sanity_check_sr, int): assert sr == sanity_check_sr, f'''{meta_data['file_path']}'s sample rate is {sr}'''
                if isinstance(sanity_check_sr, list): assert sr in sanity_check_sr, f'''{meta_data['file_path']}'s sample rate is {sr}'''

            sample_length = audio.shape[-1]
            meta_data_of_this_file = {
                'file_name': meta_data['file_name'],
                'file_path': os.path.abspath(meta_data['file_path']),
                'sample_length': sample_length,
                'sample_rate': sr,
                'duration_second': sample_length / sr,
            }

            # Mirror the source directory layout under result_save_dir.
            save_dir:str = meta_data['dir_path'].replace(data_dir, result_save_dir)
            if save_each_meta: UtilData.pickle_save(f'''{save_dir}/{meta_data['file_name']}.pkl''', meta_data_of_this_file)

            total_meta_dict['total_duration_second'] += meta_data_of_this_file['duration_second']
            if total_meta_dict['longest_sample_meta']['duration_second'] < meta_data_of_this_file['duration_second']:
                total_meta_dict['longest_sample_meta'] = meta_data_of_this_file

        total_meta_dict['total_duration_minutes'] = total_meta_dict['total_duration_second'] / 60
        total_meta_dict['total_duration_hours'] = total_meta_dict['total_duration_second'] / 3600
        UtilData.yaml_save(save_path = f'{result_save_dir}/meta.yaml', data = total_meta_dict)