|
|
import os |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader, random_split, Subset |
|
|
import torchaudio |
|
|
from sklearn.model_selection import KFold |
|
|
import glob |
|
|
import numpy as np |
|
|
from utils.utils import repeat_expand |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_file_name(path):
    """
    Build a '<speaker>_<utterance>' identifier from an audio file path.

    The speaker name is taken from the parent directory and the utterance
    name from the file name up to its first dot. If the path is too shallow
    to contain a parent directory, the original path is returned unchanged.
    """
    parts = os.path.normpath(path).split(os.sep)
    try:
        speaker, leaf = parts[-2], parts[-1]
    except IndexError:
        # No speaker directory component — fall back to the raw path.
        return path
    stem = leaf.split('.')[0]
    return f'{speaker}_{stem}'
|
|
|
|
|
def wav_pad(wav, multiple=200):
    """
    Expand a (batch, samples) waveform so its length is a multiple of *multiple*.

    Uses ``repeat_expand`` to stretch the signal to the next multiple of
    *multiple* (length is unchanged when it already divides evenly).
    """
    n_samples = wav.shape[-1]
    # Ceil-divide to the next multiple without importing math.
    target_len = -(-n_samples // multiple) * multiple
    return repeat_expand(wav, target_len)
|
|
|
|
|
class AudioDataset(torch.utils.data.Dataset):
    """
    Dataset keyed on 16 kHz audio paths with companion features on disk.

    Each 16 kHz path maps to sibling directories ('mel_44k', 'audio_44k',
    'vq_post', 'spk') holding the pre-computed features for that utterance.
    """

    def __init__(self, audio_paths, transform=None):
        """
        Initialize the dataset.

        :param audio_paths: list of 16 kHz audio file paths; all companion
            feature paths are derived from these.
        :param transform: optional transform for each sample.
            NOTE(review): stored but never applied in ``__getitem__`` —
            confirm whether it should be wired in.
        """
        self.audio_paths = audio_paths
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.audio_paths)

    @staticmethod
    def _derived_path(audio_16k_path, folder, ext):
        """Map a 16 kHz audio path to its companion feature path in *folder*."""
        return os.path.splitext(audio_16k_path)[0].replace('audio_16k', folder) + ext

    def __getitem__(self, idx):
        """
        Load one sample by index.

        :param idx: sample index.
        :returns: tuple ``(wav_44k, mel_44k, vq_post, spk, file_name)``.
        :raises IndexError: if ``idx`` is out of range.
        :raises FileNotFoundError: if the mel-spectrogram file is missing.
        """
        if idx >= len(self.audio_paths):
            raise IndexError("Index out of bounds")

        audio_16k_path = self.audio_paths[idx]
        mel_44k_path = self._derived_path(audio_16k_path, 'mel_44k', '.npy')
        audio_44k_path = self._derived_path(audio_16k_path, 'audio_44k', '.wav')
        vq_post_path = self._derived_path(audio_16k_path, 'vq_post', '.npy')
        spk_path = self._derived_path(audio_16k_path, 'spk', '.npy')

        wav_44k, _ = torchaudio.load(audio_44k_path)
        # Pad so downstream frame-based processing divides evenly.
        wav_44k = wav_pad(wav_44k)

        if not os.path.isfile(mel_44k_path):
            raise FileNotFoundError(f"Mel spectrogram file not found: {mel_44k_path}")

        mel_44k = torch.from_numpy(np.load(mel_44k_path))
        vq_post = torch.from_numpy(np.load(vq_post_path))
        spk = torch.from_numpy(np.load(spk_path))
        file_name = get_file_name(audio_16k_path)

        return wav_44k, mel_44k, vq_post, spk, file_name
|
|
|
|
|
def load_audio_data(data_path, batch_size=64, validation_ratio=0.1, demo_num=0):
    """
    Build train/validation/test DataLoaders from a dataset directory.

    Expects ``<data_path>/{train,test}/audio_16k/<speaker>/*.wav``.

    :param data_path: root directory of the dataset.
    :param batch_size: batch size shared by all three loaders.
    :param validation_ratio: fraction of the training files held out for
        validation (split randomly via ``random_split``).
    :param demo_num: if a positive int, truncate the training set to this
        many files (and the test set to ``demo_num // 9``) for quick runs.
    :returns: ``(train_loader, val_loader, test_loader)``.
    """
    # recursive=True makes '**' match speaker directories at any depth;
    # without it, glob treats '**' as a single-level '*'.
    train_audio_16k_paths = glob.glob(
        os.path.join(data_path, 'train', 'audio_16k', '**', '*.wav'),
        recursive=True)
    test_audio_16k_paths = glob.glob(
        os.path.join(data_path, 'test', 'audio_16k', '**', '*.wav'),
        recursive=True)

    # List slicing never raises, so no try/except is needed here.
    if isinstance(demo_num, int) and demo_num > 0:
        train_audio_16k_paths = train_audio_16k_paths[:demo_num]
        test_audio_16k_paths = test_audio_16k_paths[:demo_num // 9]

    dataset = AudioDataset(train_audio_16k_paths)

    total_size = len(dataset)
    val_size = int(validation_ratio * total_size)
    train_size = total_size - val_size

    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    test_dataset = AudioDataset(test_audio_16k_paths)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader
|
|
|
|
|
def load_audio_data_k_fold(data_path, k_folds=5, batch_size=64, demo_num=0):
    """
    Build k-fold cross-validation DataLoaders from a flat directory of wavs.

    :param data_path: directory containing ``*.wav`` files directly.
    :param k_folds: number of folds for the KFold split.
    :param batch_size: batch size for every loader.
    :param demo_num: if a positive int, truncate the file list for quick runs.
    :returns: ``(fold_loaders, test_loader)`` where ``fold_loaders`` is a
        list of ``(train_loader, val_loader)`` pairs, one per fold.
    """
    audio_paths = [os.path.join(data_path, f)
                   for f in os.listdir(data_path) if f.endswith('.wav')]

    # List slicing never raises, so no try/except is needed here.
    if isinstance(demo_num, int) and demo_num > 0:
        audio_paths = audio_paths[:demo_num]

    full_dataset = AudioDataset(audio_paths)

    # Fixed seed so fold assignment is reproducible across runs.
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    fold_loaders = []
    for train_index, val_index in kf.split(full_dataset):
        train_loader = DataLoader(dataset=Subset(full_dataset, train_index),
                                  batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=Subset(full_dataset, val_index),
                                batch_size=batch_size, shuffle=False)
        fold_loaders.append((train_loader, val_loader))

    # NOTE(review): this "test" loader iterates the same data used for CV —
    # there is no held-out test split here; confirm that is intended.
    test_loader = DataLoader(dataset=full_dataset, batch_size=batch_size, shuffle=False)

    return fold_loaders, test_loader
|
|
|
|
|
class FocalLoss:
    """
    Focal loss for multi-class classification (Lin et al., 2017).

    Reduces to (optionally class-weighted) cross-entropy when ``gamma == 0``.
    """

    def __init__(self, alpha_t=None, gamma=0):
        """
        :param alpha_t: optional list of per-class weights.
        :param gamma: focusing parameter; larger values down-weight
            well-classified examples. 0 disables focusing entirely.
        """
        # `is not None` instead of truthiness: a multi-element tensor passed
        # as alpha_t would raise on `if alpha_t`, and an empty weight list
        # would be silently dropped.
        self.alpha_t = torch.tensor(alpha_t) if alpha_t is not None else None
        self.gamma = gamma

    def __call__(self, outputs, targets):
        """
        Compute the focal loss with mean reduction.

        :param outputs: raw logits of shape (batch, num_classes).
        :param targets: class indices of shape (batch,).
        :returns: scalar loss tensor.
        """
        if self.alpha_t is not None and self.alpha_t.device != outputs.device:
            # Lazily move class weights to the logits' device/dtype once.
            self.alpha_t = self.alpha_t.to(outputs)

        if self.gamma == 0:
            # Plain (possibly class-weighted) cross-entropy; weight=None is
            # equivalent to the unweighted call.
            return torch.nn.functional.cross_entropy(outputs, targets,
                                                     weight=self.alpha_t)

        # p_t comes from the *unweighted* CE, per the original focal-loss
        # formulation; the class weights scale only the final per-sample loss.
        ce_unweighted = torch.nn.functional.cross_entropy(outputs, targets,
                                                          reduction='none')
        p_t = torch.exp(-ce_unweighted)
        ce_loss = torch.nn.functional.cross_entropy(outputs, targets,
                                                    weight=self.alpha_t,
                                                    reduction='none')
        return ((1 - p_t) ** self.gamma * ce_loss).mean()