| | import torch |
| | from torch.utils.data import Dataset |
| | import json |
| | import torchaudio |
| | import os |
| | from typing import Optional, Dict, Any, List, Tuple |
| | import pandas as pd |
| | import warnings |
| | import random |
| |
|
class TIMITDataset(Dataset):
    """
    TIMIT dataset returning a fixed-length audio clip plus one sampled
    prompt/answer pair per utterance.

    Each JSON entry is read with ``entry['audio_path']`` (relative to
    ``timit_root``) and ``entry['speaker']['id']``. It may optionally carry
    parallel ``prompts`` / ``responses`` lists.

    Args:
        json_path (str): Path to the JSON file containing TIMIT data. The
            top-level value is indexed by integer position (presumably a list
            of entries).
        timit_root (str): Root directory containing TIMIT audio files.
        sample_rate (int, optional): Target sample rate for audio. Audio at a
            different rate is resampled. Defaults to 16000.
        normalize_audio (bool, optional): Whether to standardize each
            utterance to zero mean / unit variance before cropping.
            Defaults to True.
        segment_seconds (float, optional): Length of the returned clip in
            seconds. Longer audio is randomly cropped; shorter audio is
            zero-padded at the end. Defaults to 3.0.

    Raises:
        ValueError: If ``segment_seconds`` is not positive.

    ``__getitem__`` returns a dict containing:
        - audio_tensor: torch.Tensor of shape
          (channels, round(segment_seconds * sample_rate))
        - sid: the speaker identifier from ``entry['speaker']['id']``
        - prompt: one randomly chosen prompt, or None when ``prompts`` and
          ``responses`` are missing, empty, or of different lengths
        - answer: the matching response with newlines replaced by spaces and
          surrounding whitespace stripped, or None (same condition as prompt)
        - filename: str, full path of the loaded audio file
    """

    def __init__(
        self,
        json_path: str,
        timit_root: str,
        sample_rate: int = 16000,
        normalize_audio: bool = True,
        segment_seconds: float = 3.0,
    ):
        super().__init__()

        if segment_seconds <= 0:
            raise ValueError(f"segment_seconds must be positive, got {segment_seconds}")

        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        self.timit_root = timit_root
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        self.segment_seconds = segment_seconds

    def __len__(self) -> int:
        return len(self.data)

    @staticmethod
    def _standardize(audio: torch.Tensor) -> torch.Tensor:
        """Return ``audio`` shifted to zero mean and scaled to unit std (over all elements)."""
        mean = torch.mean(audio)
        # torch.std applies Bessel's correction, which yields NaN for a single
        # element; treat such a clip as having zero spread instead.
        std = torch.std(audio) if audio.numel() > 1 else audio.new_zeros(())
        return (audio - mean) / (std + 1e-8)

    def _crop_or_pad(self, audio: torch.Tensor) -> torch.Tensor:
        """Randomly crop or right-pad ``audio`` along dim 1 to the configured segment length."""
        target = int(round(self.segment_seconds * self.sample_rate))
        num_samples = audio.shape[1]
        if num_samples >= target:
            start = random.randint(0, num_samples - target)
            return audio[:, start:start + target]
        return torch.nn.functional.pad(audio, (0, target - num_samples))

    @staticmethod
    def _sample_prompt_answer(sample: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
        """Pick one aligned (prompt, answer) pair at random, or (None, None) if unavailable."""
        prompts = sample.get('prompts', [])
        answers = sample.get('responses', [])
        # Only sample when both lists exist and line up index-for-index.
        if prompts and answers and len(prompts) == len(answers):
            rand_idx = random.randint(0, len(prompts) - 1)
            return prompts[rand_idx], answers[rand_idx].replace("\n", " ").strip()
        return None, None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = self.data[idx]
        audio_path = os.path.join(self.timit_root, sample['audio_path'])

        audio, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # Statistics are computed over the whole utterance before cropping, so
        # they do not depend on which window is drawn below.
        if self.normalize_audio:
            audio = self._standardize(audio)

        # Crop first, then sample the prompt: keeps the original order of RNG draws.
        audio = self._crop_or_pad(audio)
        prompt, answer = self._sample_prompt_answer(sample)

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'prompt': prompt,
            'answer': answer,
            'filename': audio_path,
        }