# hexa-tts-5b / src/dataset.py
# Uploaded via huggingface_hub by Hexa09 (commit e729286, verified)
import torch
import torchaudio
import os
from torch.utils.data import Dataset
from .text_encoder import TextEncoder
from .config import HexaConfig
class HexaDataset(Dataset):
    """
    Real Dataset Loader for Hexa TTS.

    Expects a directory structure:
        /data_root
            /wavs/           (one .wav per utterance)
            metadata.csv     (pipe-separated: filename|text, filename without extension)

    Each item is a tuple ``(text_ids, speaker, lang, emotion, mel)`` where
    ``text_ids`` is a 1-D LongTensor of token ids and ``mel`` is
    ``[frames, n_mel_channels]``.
    """
    def __init__(self, root_dir, config: HexaConfig, train=True):
        self.root_dir = root_dir
        self.config = config
        self.encoder = TextEncoder()
        self.wav_dir = os.path.join(root_dir, "wavs")
        self.metadata_path = os.path.join(root_dir, "metadata.csv")
        self.files = []
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('|')
                    # Lines with extra '|' keep only the first two fields,
                    # matching the original loader's behavior.
                    if len(parts) >= 2:
                        self.files.append((parts[0], parts[1]))
        else:
            print(f"Warning: Metadata not found at {self.metadata_path}")
        # Mel Spectrogram Transform (built once; reused for every item)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_mels=config.n_mel_channels
        )
        # Cache of Resample transforms keyed by source sample rate, so we
        # don't rebuild the (kernel-computing) transform on every __getitem__.
        self._resamplers = {}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename, text = self.files[idx]
        wav_path = os.path.join(self.wav_dir, filename + ".wav")
        # 1. Load Audio -> [channels, samples]
        waveform, sr = torchaudio.load(wav_path)
        # Mix down to mono: the squeeze(0) below assumes a single channel and
        # would silently produce a wrong-shaped mel for stereo files.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Resample if needed (cached per source rate)
        if sr != self.config.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    sr, self.config.sample_rate
                )
            waveform = self._resamplers[sr](waveform)
        # 2. Compute Mel
        mel = self.mel_transform(waveform)        # [1, n_mels, frames]
        mel = mel.squeeze(0).transpose(0, 1)      # [frames, n_mels]
        # 3. Tokenize Text
        # Assuming English for starter dataset (LJSpeech)
        text_ids = self.encoder.preprocess(text, lang_code='en').squeeze(0)
        # 4. Dummy Speaker/Lang/Emotion for single-speaker dataset
        speaker = torch.tensor(0)
        lang = torch.tensor(0)
        emotion = torch.tensor(0)
        return text_ids, speaker, lang, emotion, mel
def collate_fn(batch):
    """Collate a list of dataset samples into padded batch tensors.

    Each sample is ``(text_ids, speaker, lang, emotion, mel)``. The batch is
    sorted in place by descending text length (useful for sequence packing),
    then texts and mels are zero-padded to the longest in the batch.
    Returns ``(text_padded, speakers, langs, emotions, mel_padded)``,
    all batch-first.
    """
    # Longest text first; note this mutates the caller's list, as before.
    batch.sort(key=lambda sample: sample[0].size(0), reverse=True)

    texts, speakers, langs, emotions, mels = zip(*batch)

    pad = torch.nn.utils.rnn.pad_sequence
    padded_text = pad(texts, batch_first=True, padding_value=0)
    padded_mel = pad(mels, batch_first=True, padding_value=0.0)

    return (
        padded_text,
        torch.stack(speakers),
        torch.stack(langs),
        torch.stack(emotions),
        padded_mel,
    )