|
|
from ldm.data.preprocess.NAT_mel import MelNet |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
from glob import glob |
|
|
import math |
|
|
import pandas as pd |
|
|
import logging |
|
|
import math |
|
|
import audioread |
|
|
from tqdm.contrib.concurrent import process_map |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
from torch.distributed import init_process_group |
|
|
from torch.utils.data import Dataset,DataLoader,DistributedSampler |
|
|
import torch.multiprocessing as mp |
|
|
from argparse import Namespace |
|
|
from multiprocessing import Pool |
|
|
import json |
|
|
|
|
|
|
|
|
class tsv_dataset(Dataset): |
|
|
def __init__(self,tsv_path,sr,mode='none',hop_size = None,target_mel_length = None) -> None: |
|
|
super().__init__() |
|
|
if os.path.isdir(tsv_path): |
|
|
files = glob(os.path.join(tsv_path,'*.tsv')) |
|
|
df = pd.concat([pd.read_csv(file,sep='\t') for file in files]) |
|
|
else: |
|
|
df = pd.read_csv(tsv_path,sep='\t') |
|
|
self.audio_paths = [] |
|
|
self.sr = sr |
|
|
self.mode = mode |
|
|
self.target_mel_length = target_mel_length |
|
|
self.hop_size = hop_size |
|
|
for t in tqdm(df.itertuples()): |
|
|
self.audio_paths.append(getattr(t,'audio_path')) |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.audio_paths) |
|
|
|
|
|
def pad_wav(self,wav): |
|
|
|
|
|
wav_length = wav.shape[-1] |
|
|
assert wav_length > 100, "wav is too short, %s" % wav_length |
|
|
segment_length = (self.target_mel_length + 1) * self.hop_size |
|
|
if segment_length is None or wav_length == segment_length: |
|
|
return wav |
|
|
elif wav_length > segment_length: |
|
|
return wav[:,:segment_length] |
|
|
elif wav_length < segment_length: |
|
|
temp_wav = torch.zeros((1, segment_length),dtype=torch.float32) |
|
|
temp_wav[:, :wav_length] = wav |
|
|
return temp_wav |
|
|
|
|
|
|
|
|
def __getitem__(self, index): |
|
|
audio_path = self.audio_paths[index] |
|
|
wav, orisr = torchaudio.load(audio_path) |
|
|
if wav.shape[0] != 1: |
|
|
wav = wav.mean(0,keepdim=True) |
|
|
wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr) |
|
|
if self.mode == 'pad': |
|
|
assert self.target_mel_length is not None |
|
|
wav = self.pad_wav(wav) |
|
|
return audio_path,wav |
|
|
|
|
|
def process_audio_by_tsv(rank,args): |
|
|
if args.num_gpus > 1: |
|
|
init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'], |
|
|
world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank) |
|
|
|
|
|
sr = args.audio_sample_rate |
|
|
dataset = tsv_dataset(args.tsv_path,sr = sr,mode=args.mode,hop_size=args.hop_size,target_mel_length=args.batch_max_length) |
|
|
sampler = DistributedSampler(dataset,shuffle=False) if args.num_gpus > 1 else None |
|
|
|
|
|
loader = DataLoader(dataset, sampler=sampler,batch_size=1, num_workers=16,drop_last=False) |
|
|
|
|
|
device = torch.device('cuda:{:d}'.format(rank)) |
|
|
|
|
|
mel_net = MelNet(args.__dict__) |
|
|
mel_net.to(device) |
|
|
|
|
|
|
|
|
|
|
|
loader = tqdm(loader) if rank == 0 else loader |
|
|
for batch in loader: |
|
|
audio_paths,wavs = batch |
|
|
wavs = wavs.to(device) |
|
|
if args.save_resample: |
|
|
for audio_path,wav in zip(audio_paths,wavs): |
|
|
psplits = audio_path.split('/') |
|
|
root,wav_name = psplits[0],psplits[-1] |
|
|
|
|
|
resample_root,resample_name = root+f'_{sr}',wav_name[:-4]+'_audio.npy' |
|
|
resample_dir_name = os.path.join(resample_root,*psplits[1:-1]) |
|
|
resample_path = os.path.join(resample_dir_name,resample_name) |
|
|
os.makedirs(resample_dir_name,exist_ok=True) |
|
|
np.save(resample_path,wav.cpu().numpy().squeeze(0)) |
|
|
|
|
|
if args.save_mel: |
|
|
mode = args.mode |
|
|
batch_max_length = args.batch_max_length |
|
|
|
|
|
for audio_path,wav in zip(audio_paths,wavs): |
|
|
psplits = audio_path.split('/') |
|
|
root,wav_name = psplits[0],psplits[-1] |
|
|
mel_root,mel_name = root+f'_mel{mode}{sr}nfft{args.fft_size}',wav_name[:-4]+'_mel.npy' |
|
|
mel_dir_name = os.path.join(mel_root,*psplits[1:-1]) |
|
|
mel_path = os.path.join(mel_dir_name,mel_name) |
|
|
if not os.path.exists(mel_path): |
|
|
mel_spec = mel_net(wav).cpu().numpy().squeeze(0) |
|
|
if mel_spec.shape[1] <= batch_max_length: |
|
|
if mode == 'tile': |
|
|
n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1]) |
|
|
mel_spec = np.tile(mel_spec,reps=(1,n_repeat)) |
|
|
elif mode == 'none' or mode == 'pad': |
|
|
pass |
|
|
else: |
|
|
raise ValueError(f'mode:{mode} is not supported') |
|
|
mel_spec = mel_spec[:,:batch_max_length] |
|
|
os.makedirs(mel_dir_name,exist_ok=True) |
|
|
np.save(mel_path,mel_spec) |
|
|
|
|
|
|
|
|
def split_list(i_list,num): |
|
|
each_num = math.ceil(i_list / num) |
|
|
result = [] |
|
|
for i in range(num): |
|
|
s = each_num * i |
|
|
e = (each_num * (i+1)) |
|
|
result.append(i_list[s:e]) |
|
|
return result |
|
|
|
|
|
|
|
|
def drop_bad_wav(item): |
|
|
index,path = item |
|
|
try: |
|
|
with audioread.audio_open(path) as f: |
|
|
totalsec = f.duration |
|
|
if totalsec < 0.1: |
|
|
return index |
|
|
except: |
|
|
print(f"corrupted wav:{path}") |
|
|
return index |
|
|
return False |
|
|
|
|
|
def drop_bad_wavs(tsv_path): |
|
|
df = pd.read_csv(tsv_path,sep='\t') |
|
|
item_list = [] |
|
|
for item in tqdm(df.itertuples()): |
|
|
item_list.append((item[0],getattr(item,'audio_path'))) |
|
|
|
|
|
r = process_map(drop_bad_wav,item_list,max_workers=16,chunksize=16) |
|
|
bad_indices = list(filter(lambda x:x!= False,r)) |
|
|
|
|
|
print(bad_indices) |
|
|
with open('bad_wavs.json','w') as f: |
|
|
x = [item_list[i] for i in bad_indices] |
|
|
json.dump(x,f) |
|
|
df = df.drop(bad_indices,axis=0) |
|
|
df.to_csv(tsv_path,sep='\t',index=False) |
|
|
|
|
|
if __name__ == '__main__': |
|
|
logging.basicConfig(filename='example.log', level=logging.INFO, |
|
|
format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p') |
|
|
tsv_path = './musiccap.tsv' |
|
|
if os.path.isdir(tsv_path): |
|
|
files = glob(os.path.join(tsv_path,'*.tsv')) |
|
|
for file in files: |
|
|
drop_bad_wavs(file) |
|
|
else: |
|
|
drop_bad_wavs(tsv_path) |
|
|
num_gpus = 1 |
|
|
args = { |
|
|
'audio_sample_rate': 16000, |
|
|
'audio_num_mel_bins':80, |
|
|
'fft_size': 1024, |
|
|
'win_size': 1024, |
|
|
'hop_size': 256, |
|
|
'fmin': 0, |
|
|
'fmax': 8000, |
|
|
'batch_max_length': 1560, |
|
|
'tsv_path': tsv_path, |
|
|
'num_gpus': num_gpus, |
|
|
'mode': 'none', |
|
|
'save_resample':False, |
|
|
'save_mel' :True |
|
|
} |
|
|
args = Namespace(**args) |
|
|
args.dist_config = { |
|
|
"dist_backend": "nccl", |
|
|
"dist_url": "tcp://localhost:54189", |
|
|
"world_size": 1 |
|
|
} |
|
|
if args.num_gpus>1: |
|
|
mp.spawn(process_audio_by_tsv,nprocs=args.num_gpus,args=(args,)) |
|
|
else: |
|
|
process_audio_by_tsv(0,args=args) |
|
|
print("done") |
|
|
|
|
|
|