# NOTE(review): the lines below were scraped from a hosting platform page
# ("Spaces: / Runtime error") and are not part of the module source.
| import torch | |
| import torchaudio | |
| import soundfile as sf # Direct import (Cleaned) | |
| import os | |
| from torch.utils.data import Dataset | |
| from .text_encoder import TextEncoder | |
| from .config import HexaConfig | |
class HexaDataset(Dataset):
    """
    Real Dataset Loader for Hexa TTS.

    Expects a directory structure:
        /data_root
            /wavs/              (one .wav per utterance)
            metadata.csv        (lines formatted: filename|text)

    Each item yields ``(text_ids, speaker, lang, emotion, mel)`` where
    ``mel`` is ``[frames, n_mel_channels]``.
    """

    def __init__(self, root_dir, config: HexaConfig, train=True):
        """
        Args:
            root_dir: dataset root containing ``wavs/`` and ``metadata.csv``.
            config:   HexaConfig carrying audio hyperparameters
                      (sample_rate, n_fft, win_length, hop_length,
                      n_mel_channels).
            train:    unused here; kept for interface compatibility.
        """
        self.root_dir = root_dir
        self.config = config
        self.encoder = TextEncoder()
        self.wav_dir = os.path.join(root_dir, "wavs")
        self.metadata_path = os.path.join(root_dir, "metadata.csv")

        # Parse metadata lines "filename|text"; extra '|' fields are ignored.
        self.files = []
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('|')
                    if len(parts) >= 2:
                        self.files.append((parts[0], parts[1]))
        else:
            print(f"Warning: Metadata not found at {self.metadata_path}")

        # Mel Spectrogram Transform (built once, reused per item).
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_mels=config.n_mel_channels
        )
        # Cache of Resample transforms keyed by source sample rate, so we
        # do not rebuild the (kernel-precomputing) transform on every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename, text = self.files[idx]
        wav_path = os.path.join(self.wav_dir, filename + ".wav")

        # 1. Load audio using SoundFile (direct read) — avoids torchaudio
        #    backend dependency issues.
        wav, sr = sf.read(wav_path)
        waveform = torch.from_numpy(wav).float()

        # soundfile returns [samples] (mono) or [samples, channels];
        # torchaudio transforms expect [channels, samples].
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)        # [1, samples]
        else:
            waveform = waveform.transpose(0, 1)     # [channels, samples]

        # BUG FIX: downmix multi-channel audio to mono. The squeeze(0) on the
        # mel below only removes the channel dim when it is 1; stereo input
        # previously produced a wrongly-shaped mel tensor.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)  # [1, samples]

        # Resample if needed (cached transform per source rate).
        if sr != self.config.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    sr, self.config.sample_rate
                )
            waveform = self._resamplers[sr](waveform)

        # 2. Compute Mel: [1, n_mels, frames] -> [frames, n_mels]
        mel = self.mel_transform(waveform)
        mel = mel.squeeze(0).transpose(0, 1)

        # 3. Tokenize Text — assuming English for starter dataset (LJSpeech).
        text_ids = self.encoder.preprocess(text, lang_code='en').squeeze(0)

        # 4. Dummy speaker/lang/emotion IDs for a single-speaker dataset.
        speaker = torch.tensor(0)
        lang = torch.tensor(0)
        emotion = torch.tensor(0)

        return text_ids, speaker, lang, emotion, mel
def collate_fn(batch):
    """
    Collate dataset samples into padded batch tensors.

    Sorts the batch by text length (descending, useful for sequence
    packing), then pads text IDs and mel spectrograms to the longest
    sequence in the batch. Scalar speaker/lang/emotion IDs are stacked.

    Returns:
        (text_padded, speakers, langs, emotions, mel_padded)
    """
    # Longest text first; sorts the caller's list in place, as before.
    batch.sort(key=lambda sample: sample[0].shape[0], reverse=True)
    text_ids, speakers, langs, emotions, mels = zip(*batch)

    pad = torch.nn.utils.rnn.pad_sequence
    text_padded = pad(text_ids, batch_first=True, padding_value=0)
    mel_padded = pad(mels, batch_first=True, padding_value=0.0)

    return (
        text_padded,
        torch.stack(speakers),
        torch.stack(langs),
        torch.stack(emotions),
        mel_padded,
    )
# FORCE UPLOAD HASH BUSTER
# NOTE(review): marker constant, presumably bumped to change the file hash and
# force the hosting platform to re-upload/redeploy — confirm before removing.
__version_fix__ = "1.0.clean"