File size: 4,422 Bytes
be29b5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torchaudio
import pandas as pd
import os
import soundfile as sf
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


# --- CONFIGURATION ---
# Character-level vocabulary. IDs 0 and 1 are reserved (0 = padding,
# 1 = 'unknown'), so real characters start at ID 2.
vocab = "_ abcdefghijklmnopqrstuvwxyz'.?"
char_to_id = {symbol: idx for idx, symbol in enumerate(vocab, start=2)}
id_to_char = {idx: symbol for symbol, idx in char_to_id.items()}

class TextProcessor:
    """Converts raw transcript text into integer ID sequences for the model."""

    @staticmethod
    def text_to_sequence(text):
        """Lowercase *text* and map each in-vocabulary character to its ID.

        Characters outside ``vocab`` are silently dropped, so the returned
        tensor may be shorter than the input (or empty).

        Returns:
            1-D ``torch.long`` tensor of character IDs.
        """
        text = text.lower()
        # The original used char_to_id.get(c, 1) (the 'unknown' ID), but the
        # `if c in vocab` filter made that fallback unreachable — direct
        # indexing is equivalent and makes the drop-out-of-vocab policy clear.
        sequence = [char_to_id[c] for c in text if c in vocab]
        return torch.tensor(sequence, dtype=torch.long)

class LJSpeechDataset(Dataset):
    """LJSpeech (text, mel-spectrogram) pairs for TTS training.

    Each item is ``(text_ids, mel)`` where ``text_ids`` is a 1-D long tensor
    of character IDs and ``mel`` is a float tensor of shape ``[frames, 80]``.
    """

    # Target sample rate for LJSpeech audio; also the rate assumed by the
    # MelSpectrogram transform below.
    TARGET_SR = 22050

    def __init__(self, metadata_path, wavs_dir, max_items=100):
        """
        metadata_path: Path to metadata.csv
        wavs_dir: Path to the folder containing .wav files
        max_items: Cap on the number of rows loaded. Defaults to 100 to keep
            the original debug-sized subset; pass None to load everything.
        """
        self.wavs_dir = wavs_dir
        # metadata.csv format: ID | Transcription | Normalized Transcription.
        # quoting=3 (csv.QUOTE_NONE) because transcripts contain bare quotes.
        metadata = pd.read_csv(metadata_path, sep='|', header=None, quoting=3)
        self.metadata = metadata if max_items is None else metadata.iloc[:max_items]

        # Cache one Resample transform per encountered source rate instead of
        # rebuilding it for every item in __getitem__.
        self._resamplers = {}

        # Audio Processing Setup (Mel Spectrogram)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.TARGET_SR,
            n_fft=1024,
            # NOTE(review): win_length=256 (< n_fft) is unusual — standard TTS
            # setups use win_length == n_fft (1024) with hop_length=256.
            # Kept as-is to preserve the current spectrogram values; confirm
            # whether this was intentional.
            win_length=256,
            hop_length=256,
            n_mels=80,    # Standard for TTS (must match the model's mel dim)
        )

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # --- Text ---
        row = self.metadata.iloc[idx]
        file_id = row[0]
        text = row[2]  # normalized transcription column
        text_tensor = TextProcessor.text_to_sequence(str(text))

        # --- Audio (soundfile bypasses the torchaudio loader backend) ---
        wav_path = os.path.join(self.wavs_dir, f"{file_id}.wav")
        # sf.read returns (audio_array: numpy, sample_rate: int)
        audio_np, sample_rate = sf.read(wav_path)

        # Soundfile gives [time] or [time, channels]; PyTorch wants
        # [channels, time].
        waveform = torch.from_numpy(audio_np).float()
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)      # [time] -> [1, time]
        else:
            waveform = waveform.transpose(0, 1)   # [time, ch] -> [ch, time]

        # Resample if necessary (cached transform, see __init__).
        if sample_rate != self.TARGET_SR:
            if sample_rate not in self._resamplers:
                self._resamplers[sample_rate] = torchaudio.transforms.Resample(
                    sample_rate, self.TARGET_SR
                )
            waveform = self._resamplers[sample_rate](waveform)

        # [1, n_mels, frames] -> [frames, n_mels]. Assumes mono input here;
        # a stereo file would keep its channel axis — TODO confirm upstream
        # data is mono.
        mel_spec = self.mel_transform(waveform).squeeze(0).transpose(0, 1)

        return text_tensor, mel_spec

def collate_fn_tts(batch):
    """Pad a batch of variable-length (text, mel) pairs into stacked tensors.

    batch: list of tuples [(text1, mel1), (text2, mel2), ...]

    Returns:
        (text_padded, mel_padded) with shapes [batch, max_text_len] and
        [batch, max_frames, n_mels]. Text is padded with 0 (the PAD id),
        mels with 0.0.
    """
    # Split the pairs into parallel sequences of texts and mels.
    texts, mels = zip(*batch)

    # batch_first=True -> output shaped [batch, max_len, ...]
    padded_texts = pad_sequence(list(texts), batch_first=True, padding_value=0)
    padded_mels = pad_sequence(list(mels), batch_first=True, padding_value=0.0)

    return padded_texts, padded_mels

# --- SANITY CHECK ---
# Loads a tiny batch end-to-end to verify the data pipeline.
if __name__ == "__main__":
    # Adjust BASE_PATH to wherever LJSpeech-1.1 was extracted.
    BASE_PATH = "LJSpeech-1.1"
    csv_path = os.path.join(BASE_PATH, "metadata.csv")
    wav_path = os.path.join(BASE_PATH, "wavs")

    if not os.path.exists(csv_path):
        print("Dataset not found. Please download LJSpeech-1.1 to run this test.")
    else:
        print("Loading Dataset...")
        dataset = LJSpeechDataset(csv_path, wav_path)
        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn_tts)

        # Pull a single batch through the collate function.
        text_batch, mel_batch = next(iter(loader))

        print(f"Text Batch Shape: {text_batch.shape} (Batch, Max Text Len)")
        print(f"Mel Batch Shape:  {mel_batch.shape}  (Batch, Max Audio Len, 80)")
        print("\nSUCCESS: Data pipeline is working!")