import torch
import torchaudio
import os
from torch.utils.data import Dataset

from .text_encoder import TextEncoder
from .config import HexaConfig


class HexaDataset(Dataset):
    """Real dataset loader for Hexa TTS.

    Expects a directory structure:
        /data_root
            /wavs/
            metadata.csv   (formatted: filename|text)
    """

    def __init__(self, root_dir, config: HexaConfig, train=True):
        self.root_dir = root_dir
        self.config = config
        self.encoder = TextEncoder()
        self.wav_dir = os.path.join(root_dir, "wavs")
        self.metadata_path = os.path.join(root_dir, "metadata.csv")

        # (filename, text) pairs parsed from metadata.csv; lines without at
        # least one '|' separator are skipped.
        self.files = []
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('|')
                    if len(parts) >= 2:
                        self.files.append((parts[0], parts[1]))
        else:
            print(f"Warning: Metadata not found at {self.metadata_path}")

        # Mel Spectrogram Transform, configured from the model config.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_mels=config.n_mel_channels
        )

        # Cache of Resample transforms keyed by source sample rate: Resample
        # precomputes its filter kernels, so rebuilding it on every
        # __getitem__ call (as before) is wasteful.
        self._resamplers = {}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return one sample: (text_ids, speaker, lang, emotion, mel).

        mel has shape [frames, n_mel_channels]; speaker/lang/emotion are
        scalar tensors (all 0 for a single-speaker dataset).
        """
        filename, text = self.files[idx]
        wav_path = os.path.join(self.wav_dir, filename + ".wav")

        # 1. Load Audio
        waveform, sr = torchaudio.load(wav_path)

        # Downmix multi-channel audio to mono. The squeeze(0) below assumes a
        # single channel; stereo input would otherwise produce a mel of shape
        # [channels, n_mels, frames] and a wrongly-shaped output.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed (transform cached per source rate).
        if sr != self.config.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    sr, self.config.sample_rate)
            waveform = self._resamplers[sr](waveform)

        # 2. Compute Mel
        mel = self.mel_transform(waveform)    # [1, n_mels, frames]
        mel = mel.squeeze(0).transpose(0, 1)  # [frames, n_mels]

        # 3. Tokenize Text
        # Assuming English for starter dataset (LJSpeech)
        text_ids = self.encoder.preprocess(text, lang_code='en').squeeze(0)

        # 4. Dummy Speaker/Lang/Emotion for single-speaker dataset
        speaker = torch.tensor(0)
        lang = torch.tensor(0)
        emotion = torch.tensor(0)

        return text_ids, speaker, lang, emotion, mel


def collate_fn(batch):
    """Pad a batch of dataset items to the longest sequence.

    Args:
        batch: list of (text_ids, speaker, lang, emotion, mel) tuples as
            produced by HexaDataset.__getitem__.

    Returns:
        (text_padded, speakers, langs, emotions, mel_padded), where text and
        mel are zero-padded along the time axis to the batch maximum length
        and the scalar fields are stacked into 1-D tensors.
    """
    # Sort by text length for packing (optional but good practice).
    # NOTE: sorts the caller's list in place, preserving original behavior.
    batch.sort(key=lambda x: x[0].shape[0], reverse=True)

    text_ids, speakers, langs, emotions, mels = zip(*batch)

    # Pad text token sequences.
    text_padded = torch.nn.utils.rnn.pad_sequence(
        text_ids, batch_first=True, padding_value=0)

    # Pad mel spectrograms along the frame axis.
    mel_padded = torch.nn.utils.rnn.pad_sequence(
        mels, batch_first=True, padding_value=0.0)

    speakers = torch.stack(speakers)
    langs = torch.stack(langs)
    emotions = torch.stack(emotions)

    return text_padded, speakers, langs, emotions, mel_padded