import os

import pandas as pd
import soundfile as sf
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# --- CONFIGURATION ---
# Character-level vocabulary for the TTS front-end.
# IDs 0 and 1 are reserved: 0 = padding, 1 = unknown character.
vocab = "_ abcdefghijklmnopqrstuvwxyz'.?"
char_to_id = {char: i + 2 for i, char in enumerate(vocab)}
id_to_char = {i + 2: char for i, char in enumerate(vocab)}


class TextProcessor:
    """Utility for converting raw text into integer ID sequences."""

    @staticmethod
    def text_to_sequence(text):
        """Lowercase ``text`` and map each character to its vocabulary ID.

        Characters not in ``vocab`` map to the reserved "unknown" ID (1).
        (Bug fix: the original comprehension filtered out-of-vocab
        characters with ``if c in vocab``, which made the unknown ID
        unreachable and silently dropped those characters.)

        Returns:
            torch.LongTensor of shape [len(text)].
        """
        text = text.lower()
        sequence = [char_to_id.get(c, 1) for c in text]
        return torch.tensor(sequence, dtype=torch.long)


class LJSpeechDataset(Dataset):
    """LJSpeech dataset yielding (text_ids, mel_spectrogram) pairs."""

    def __init__(self, metadata_path, wavs_dir):
        """
        Args:
            metadata_path: Path to metadata.csv
                (pipe-separated: ID | Transcription | Normalized Transcription).
            wavs_dir: Path to the folder containing the .wav files.
        """
        self.wavs_dir = wavs_dir
        # quoting=3 == csv.QUOTE_NONE: transcriptions contain literal quotes.
        # NOTE(review): .iloc[:100] caps the dataset to 100 rows (debugging
        # aid) — remove for a full training run.
        self.metadata = pd.read_csv(
            metadata_path, sep='|', header=None, quoting=3
        ).iloc[:100]

        # Mel-spectrogram front-end following the common Tacotron-style
        # LJSpeech recipe (22.05 kHz, 80 mel bands).
        # Bug fix: win_length was 256 (same as hop_length), leaving 3/4 of
        # each FFT frame zero-padded; the standard recipe uses
        # win_length == n_fft == 1024.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050,
            n_fft=1024,
            win_length=1024,
            hop_length=256,
            n_mels=80,  # Standard for TTS (match this with your network.py!)
        )

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return (text_tensor [T_text], mel_spec [T_frames, 80]) for row idx."""
        # 1. Text -> integer IDs (column 2 = normalized transcription).
        row = self.metadata.iloc[idx]
        file_id = row[0]
        text = row[2]
        text_tensor = TextProcessor.text_to_sequence(str(text))

        # 2. Audio (bypassing the torchaudio loader).
        wav_path = os.path.join(self.wavs_dir, f"{file_id}.wav")
        # soundfile returns (numpy array, sample rate); the array is [time]
        # for mono or [time, channels] for multi-channel files.
        audio_np, sample_rate = sf.read(wav_path)

        waveform = torch.from_numpy(audio_np).float()
        if waveform.dim() == 1:
            # Mono: [time] -> [1, time].
            waveform = waveform.unsqueeze(0)
        else:
            # Multi-channel: [time, channels] -> [channels, time], then
            # downmix to mono. (Bug fix: the original kept all channels, so
            # the squeeze/transpose below swapped channels with mel bins
            # and produced a garbled tensor for stereo files.)
            waveform = waveform.transpose(0, 1).mean(dim=0, keepdim=True)

        # Resample to the 22.05 kHz the mel transform expects.
        if sample_rate != 22050:
            resampler = torchaudio.transforms.Resample(sample_rate, 22050)
            waveform = resampler(waveform)

        # [1, n_mels, frames] -> [frames, n_mels] (time-major for padding).
        mel_spec = self.mel_transform(waveform).squeeze(0)
        mel_spec = mel_spec.transpose(0, 1)

        return text_tensor, mel_spec


# --- BATCHING (collate function) ---
# Utterances differ in length, so each batch is padded to the longest
# element it contains.
def collate_fn_tts(batch):
    """Pad a list of (text, mel) pairs into batch tensors.

    Args:
        batch: list of tuples [(text1, mel1), (text2, mel2), ...].

    Returns:
        (text_padded [B, max_text_len], mel_padded [B, max_frames, n_mels]),
        both padded with 0 (the reserved padding ID / silence).
    """
    text_list = [item[0] for item in batch]
    mel_list = [item[1] for item in batch]

    # batch_first=True -> output shaped [batch, max_len, ...].
    text_padded = pad_sequence(text_list, batch_first=True, padding_value=0)
    mel_padded = pad_sequence(mel_list, batch_first=True, padding_value=0.0)

    return text_padded, mel_padded


# --- SANITY CHECK ---
if __name__ == "__main__":
    # UPDATE THESE PATHS TO MATCH YOUR FOLDER
    BASE_PATH = "LJSpeech-1.1"
    csv_path = os.path.join(BASE_PATH, "metadata.csv")
    wav_path = os.path.join(BASE_PATH, "wavs")

    if os.path.exists(csv_path):
        print("Loading Dataset...")
        dataset = LJSpeechDataset(csv_path, wav_path)
        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn_tts)

        # Pull one batch to verify the pipeline end-to-end.
        text_batch, mel_batch = next(iter(loader))
        print(f"Text Batch Shape: {text_batch.shape} (Batch, Max Text Len)")
        print(f"Mel Batch Shape: {mel_batch.shape} (Batch, Max Audio Len, 80)")
        print("\nSUCCESS: Data pipeline is working!")
    else:
        print("Dataset not found. Please download LJSpeech-1.1 to run this test.")