File size: 4,616 Bytes
c41d9f4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import pytorch_lightning as pl
import random
import librosa
from os.path import basename, exists, join
from torch.utils.data import Dataset, DataLoader
import hydra
import utils
import torchaudio
from transformers import AutoFeatureExtractor
from torchaudio.transforms import Resample
from tqdm import tqdm
from torchaudio.transforms import Resample
class DataModule(pl.LightningDataModule):
    """Lightning data module that builds ``FSDataset`` loaders per phase.

    Args:
        cfg: hydra config. ``cfg.dataset.<phase>`` must provide ``batch_size``
            and ``shuffle``; it may also provide ``num_workers``
            (defaults to ``DEFAULT_NUM_WORKERS``).
    """

    # Worker count used when the phase config does not set ``num_workers``.
    DEFAULT_NUM_WORKERS = 28

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Directory hydra was launched from (hydra may change the working dir).
        self.ocwd = hydra.utils.get_original_cwd()

    def get_loader(self, phase):
        """Build the DataLoader for ``phase`` ('train', 'val', ...).

        Raises:
            KeyError: if ``cfg.dataset`` has no entry for ``phase``.
        """
        phase_cfg = self.cfg.dataset.get(phase)
        if phase_cfg is None:
            raise KeyError(f"cfg.dataset has no config for phase {phase!r}")
        ds = FSDataset(phase, self.cfg)
        num_workers = phase_cfg.get('num_workers', self.DEFAULT_NUM_WORKERS)
        return DataLoader(
            ds,
            batch_size=phase_cfg.batch_size,
            shuffle=phase_cfg.shuffle,
            num_workers=num_workers,
            collate_fn=ds.collate_fn,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=num_workers > 0,
        )

    def train_dataloader(self):
        """Return the training DataLoader."""
        return self.get_loader('train')

    def val_dataloader(self):
        """Return the validation DataLoader."""
        return self.get_loader('val')

    def test_dataloader(self):
        """No test loader is provided; returns None."""
        pass
class FSDataset(Dataset):
    """Dataset yielding fixed-length waveform crops and their extracted features.

    The filelist's first line is the audio root directory; each following
    line holds a path (relative to that root) in its first tab-separated
    column. Only paths with an extension in ``AUDIO_EXTENSIONS`` are kept.

    Args:
        phase: train, val, test
        cfg: hydra config
    """

    # Rate every clip is resampled to before feature extraction.
    TARGET_SR = 16000
    # Checkpoint used when cfg.dataset has no ``feature_extractor_path`` entry.
    DEFAULT_FEATURE_EXTRACTOR = "/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/ckpt/w2v-bert-2.0"
    # Filelist entries are kept only if they end with one of these suffixes.
    AUDIO_EXTENSIONS = ('.wav', '.flac', '.mp3')
    # Zero samples added on each side of the crop before feature extraction.
    FEAT_PAD = 160

    def __init__(self, phase, cfg):
        self.phase = phase
        self.cfg = cfg
        self.phase_cfg = cfg.dataset.get(phase)
        self.ocwd = hydra.utils.get_original_cwd()
        # NOTE(review): only load_wav uses self.sr; __getitem__ always
        # resamples to TARGET_SR.
        self.sr = cfg.dataset.sr
        self.filelist = self.get_filelist(self.phase_cfg.filelist)
        self.min_audio_length = cfg.dataset.min_audio_length
        extractor_path = cfg.dataset.get('feature_extractor_path', self.DEFAULT_FEATURE_EXTRACTOR)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path)
        # Resample transforms keyed by source rate, reused across items.
        self._resamplers = {}

    def __len__(self):
        return len(self.filelist)

    def load_wav(self, path):
        """Load ``path`` with librosa at ``self.sr`` and return the samples."""
        wav, sr = librosa.load(path, sr=self.sr)
        return wav

    def get_filelist(self, fpath):
        """Read the filelist at ``fpath``.

        Sets ``self.audio_root`` from the first line and returns the audio
        paths from the remaining lines.

        Raises:
            ValueError: if the file is empty (no audio-root line).
        """
        with open(fpath, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        if not lines:
            raise ValueError(f"filelist {fpath!r} is empty; expected the audio root on its first line")
        # The first line is the audio root directory.
        self.audio_root = lines[0].strip()
        candidates = (line.strip().split('\t')[0] for line in lines[1:])
        return [path for path in candidates if path.endswith(self.AUDIO_EXTENSIONS)]

    def _resample(self, wav, sr):
        """Resample ``wav`` from ``sr`` to ``TARGET_SR`` using a cached transform."""
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = self._resamplers[sr] = Resample(sr, self.TARGET_SR)
        return resampler(wav)

    def _fix_length(self, wav):
        """Return a random window of exactly ``min_audio_length`` samples.

        ``wav`` is a 1-D tensor. It is zero-padded at the end first if it is
        shorter than ``min_audio_length``.
        """
        target = self.min_audio_length
        length = wav.shape[0]
        if length < target:
            wav = F.pad(wav, (0, target - length))
            length = target
        start = random.randint(0, length - target)
        return wav[start:start + target]

    def __getitem__(self, idx):
        wavpath_full = join(self.audio_root, self.filelist[idx])
        wav, sr = torchaudio.load(wavpath_full)
        if sr != self.TARGET_SR:
            wav = self._resample(wav, sr)
        # Only the first channel is used.
        wav = self._fix_length(wav[0, :])
        wav_pad = F.pad(wav, (self.FEAT_PAD, self.FEAT_PAD))
        feat = self.feature_extractor(
            wav_pad, sampling_rate=self.TARGET_SR, return_tensors="pt"
        ).data['input_features']
        return {
            'wav': wav,
            'feat': feat,
        }

    def collate_fn(self, bs):
        """Stack per-item 'wav' and 'feat' tensors into batch tensors.

        The batch keys are 'wav' and 'feats'.
        """
        return {
            'wav': torch.stack([b['wav'] for b in bs]),
            'feats': torch.stack([b['feat'] for b in bs]),
        }
@hydra.main(config_path='config', config_name='default', version_base=None)
def main(cfg):
    """Smoke-test the validation loader by iterating over every batch once."""
    data_module = DataModule(cfg)
    val_loader = data_module.val_dataloader()
    for _ in tqdm(val_loader, desc="Processing batches", unit="batch"):
        pass


if __name__ == "__main__":
    main()
|