# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were Hugging Face
# Spaces page residue from a copy-paste, not Python — converted to this
# comment so the file parses.
import torch
import torchaudio
import pandas as pd
import os
import soundfile as sf
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# --- CONFIGURATION ---
# Character/ID lookup tables. IDs 0 and 1 are reserved (0 = padding,
# 1 = 'unknown'), so real vocabulary characters start at ID 2.
vocab = "_ abcdefghijklmnopqrstuvwxyz'.?"
char_to_id = {symbol: index for index, symbol in enumerate(vocab, start=2)}
id_to_char = {index: symbol for symbol, index in char_to_id.items()}
class TextProcessor:
    """Converts raw text into sequences of integer character IDs."""

    @staticmethod
    def text_to_sequence(text):
        """Lowercase *text* and map each in-vocabulary character to its ID.

        Characters outside `vocab` are silently dropped. (The original
        `.get(c, 1)` unknown-ID fallback was dead code: the `if c in vocab`
        filter removed unknowns before the lookup, so no char ever mapped
        to 1. The drop-unknown behavior is preserved here.)

        Returns a 1-D torch.LongTensor of character IDs.
        """
        # @staticmethod added: without it, calling via an instance
        # (TextProcessor().text_to_sequence(s)) would pass the instance
        # as `text`. Class-level calls keep working unchanged.
        text = text.lower()
        ids = [char_to_id[c] for c in text if c in char_to_id]
        return torch.tensor(ids, dtype=torch.long)
class LJSpeechDataset(Dataset):
    """LJSpeech dataset yielding (text_ids, mel_spectrogram) pairs.

    Each item is:
        text_tensor: torch.LongTensor [text_len] of character IDs
        mel_spec:    torch.FloatTensor [frames, 80], time-major
    """

    TARGET_SR = 22050  # LJSpeech native sample rate

    def __init__(self, metadata_path, wavs_dir, max_items=100):
        """
        metadata_path: Path to metadata.csv
        wavs_dir: Path to the folder containing .wav files
        max_items: truncate the dataset to this many rows (None = use all).
                   Default 100 preserves the original debug-sized behavior.
        """
        self.wavs_dir = wavs_dir
        # metadata.csv format: ID | Transcription | Normalized Transcription
        # quoting=3 (csv.QUOTE_NONE) because transcriptions contain quotes.
        metadata = pd.read_csv(metadata_path, sep='|', header=None, quoting=3)
        self.metadata = metadata if max_items is None else metadata.iloc[:max_items]
        # Cache Resample transforms keyed by source sample rate so the
        # transform isn't rebuilt on every __getitem__ call.
        self._resamplers = {}
        # Audio processing setup (mel spectrogram).
        # NOTE(review): win_length=256 with n_fft=1024 is unusual — TTS
        # recipes typically use win_length == n_fft (1024) with hop 256.
        # Kept as-is to preserve behavior; confirm against network.py.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.TARGET_SR,
            n_fft=1024,
            win_length=256,
            hop_length=256,
            n_mels=80  # Standard for TTS (Match this with your network.py!)
        )

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # 1. Text -> integer ID sequence
        row = self.metadata.iloc[idx]
        # .iloc for positional access: row[0] on a Series is deprecated
        # label-or-position lookup and emits a FutureWarning.
        file_id = row.iloc[0]
        text = row.iloc[2]  # normalized transcription column
        text_tensor = TextProcessor.text_to_sequence(str(text))

        # 2. Audio — read with soundfile directly (bypasses the torchaudio
        # loader). sf.read returns (numpy array, sample_rate).
        wav_path = os.path.join(self.wavs_dir, f"{file_id}.wav")
        audio_np, sample_rate = sf.read(wav_path)

        # soundfile yields [time] or [time, channels]; torch wants
        # [channels, time].
        waveform = torch.from_numpy(audio_np).float()
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)      # mono: [time] -> [1, time]
        else:
            waveform = waveform.transpose(0, 1)   # [time, ch] -> [ch, time]

        # Resample if necessary, reusing a cached transform per source rate.
        if sample_rate != self.TARGET_SR:
            if sample_rate not in self._resamplers:
                self._resamplers[sample_rate] = torchaudio.transforms.Resample(
                    sample_rate, self.TARGET_SR)
            waveform = self._resamplers[sample_rate](waveform)

        # Mel spectrogram: [1, n_mels, frames] -> [frames, n_mels]
        mel_spec = self.mel_transform(waveform).squeeze(0)
        mel_spec = mel_spec.transpose(0, 1)
        return text_tensor, mel_spec
# --- BATCHING MAGIC (Collate Function) ---
# Sentences (and their audio) vary in length, so each batch is padded up
# to its longest element.
def collate_fn_tts(batch):
    """Collate (text, mel) pairs into padded batch tensors.

    batch: list of tuples (text LongTensor [Lt], mel FloatTensor [Lm, n_mels])
    Returns (text_padded [B, max_Lt], mel_padded [B, max_Lm, n_mels]);
    text is padded with ID 0, mels with 0.0. batch_first=True gives
    [batch, max_len, ...] layouts.
    """
    texts, mels = zip(*batch)
    padded_texts = pad_sequence(list(texts), batch_first=True, padding_value=0)
    padded_mels = pad_sequence(list(mels), batch_first=True, padding_value=0.0)
    return padded_texts, padded_mels
# --- SANITY CHECK ---
if __name__ == "__main__":
    # UPDATE THESE PATHS TO MATCH YOUR FOLDER
    BASE_PATH = "LJSpeech-1.1"
    csv_path = os.path.join(BASE_PATH, "metadata.csv")
    wav_path = os.path.join(BASE_PATH, "wavs")

    # Guard clause: bail early when the dataset isn't present.
    if not os.path.exists(csv_path):
        print("Dataset not found. Please download LJSpeech-1.1 to run this test.")
    else:
        print("Loading Dataset...")
        dataset = LJSpeechDataset(csv_path, wav_path)
        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn_tts)

        # Pull a single batch to confirm the padded shapes line up.
        text_batch, mel_batch = next(iter(loader))
        print(f"Text Batch Shape: {text_batch.shape} (Batch, Max Text Len)")
        print(f"Mel Batch Shape: {mel_batch.shape} (Batch, Max Audio Len, 80)")
        print("\nSUCCESS: Data pipeline is working!")