import os

import torch
import torchaudio
from torch.utils.data import Dataset

from .config import HexaConfig
from .text_encoder import TextEncoder

class HexaDataset(Dataset):
    """
    Real Dataset Loader for Hexa TTS.

    Expects a directory structure:
        /data_root
            /wavs/
            metadata.csv   (formatted: filename|text)

    Each item is a tuple ``(text_ids, speaker, lang, emotion, mel)`` where
    ``mel`` is time-major with shape (T, n_mel_channels).
    """

    def __init__(self, root_dir, config: HexaConfig, train=True):
        """
        Args:
            root_dir: Dataset root containing ``wavs/`` and ``metadata.csv``.
            config:   HexaConfig providing audio/feature hyperparameters
                      (sample_rate, n_fft, win_length, hop_length,
                      n_mel_channels).
            train:    Currently unused; stored for future train/val split
                      logic so the signature stays stable.
        """
        self.root_dir = root_dir
        self.config = config
        self.train = train
        self.encoder = TextEncoder()

        self.wav_dir = os.path.join(root_dir, "wavs")
        self.metadata_path = os.path.join(root_dir, "metadata.csv")

        # Parse "filename|text" rows; rows with fewer than two fields
        # (including blank lines) are skipped. Extra fields are ignored.
        self.files = []
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('|')
                    if len(parts) >= 2:
                        self.files.append((parts[0], parts[1]))
        else:
            print(f"Warning: Metadata not found at {self.metadata_path}")

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_mels=config.n_mel_channels
        )
        # Cache of Resample transforms keyed by source sample rate.
        # Resample precomputes its filter kernels, so rebuilding it on
        # every __getitem__ (as the original code did) is wasteful.
        self._resamplers = {}

    def __len__(self):
        # Number of (filename, text) pairs parsed from metadata.csv.
        return len(self.files)

    def __getitem__(self, idx):
        filename, text = self.files[idx]
        # Metadata is expected to list filenames WITHOUT the ".wav" suffix.
        wav_path = os.path.join(self.wav_dir, filename + ".wav")

        waveform, sr = torchaudio.load(wav_path)

        # Downmix multi-channel audio to mono. Without this, a stereo file
        # yields a (2, n_mels, T) mel below; squeeze(0) then leaves the
        # channel dim in place and transpose(0, 1) produces garbage.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != self.config.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    sr, self.config.sample_rate
                )
            waveform = self._resamplers[sr](waveform)

        # (1, n_mels, T) -> (T, n_mels): time-major frames for collate_fn.
        mel = self.mel_transform(waveform)
        mel = mel.squeeze(0).transpose(0, 1)

        # NOTE(review): lang_code is hard-coded to 'en' even though a lang
        # id is returned below — confirm against TextEncoder's contract.
        text_ids = self.encoder.preprocess(text, lang_code='en').squeeze(0)

        # Single-speaker / single-language / neutral-emotion placeholders.
        speaker = torch.tensor(0)
        lang = torch.tensor(0)
        emotion = torch.tensor(0)

        return text_ids, speaker, lang, emotion, mel
def collate_fn(batch):
    """
    Collate (text_ids, speaker, lang, emotion, mel) samples into batch tensors.

    The batch list is sorted in place so the longest text comes first, text
    and mel sequences are zero-padded to the longest entry in the batch, and
    the scalar conditioning ids are stacked into 1-D tensors.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    # Longest text first (sorted in place, as callers may rely on).
    batch.sort(key=lambda sample: sample[0].shape[0], reverse=True)

    texts, speakers, langs, emotions, mels = zip(*batch)

    padded_texts = pad(texts, batch_first=True, padding_value=0)
    padded_mels = pad(mels, batch_first=True, padding_value=0.0)

    return (
        padded_texts,
        torch.stack(speakers),
        torch.stack(langs),
        torch.stack(emotions),
        padded_mels,
    )