import torch
from torch.utils.data import Dataset
import json
import torchaudio
import os
from typing import Optional, Dict, Any, List, Tuple
import pandas as pd
import warnings
import random
class TIMITDataset(Dataset):
    """
    TIMIT dataset that yields a fixed-length audio crop plus a randomly
    chosen prompt/response pair per sample.

    Each item's audio is resampled to ``sample_rate`` if needed, optionally
    normalized to zero mean / unit variance, and randomly cropped (or
    zero-padded on the right) to exactly 3 seconds.

    Args:
        json_path (str): Path to the JSON file containing TIMIT data
            (a list of sample dicts with at least ``audio_path`` and
            ``speaker.id``; optionally ``prompts`` and ``responses``).
        timit_root (str): Root directory containing TIMIT audio files.
        sample_rate (int, optional): Target sample rate for audio.
            Defaults to 16000.
        normalize_audio (bool, optional): Whether to normalize audio to
            zero mean / unit variance. Defaults to True.

    Each ``__getitem__`` call returns a dict containing:
        - audio_tensor: torch.Tensor of shape (channels, 3 * sample_rate)
        - sid: str, speaker identifier
        - prompt: Optional[str], randomly selected prompt (None when no
          valid prompt/response pairs exist)
        - answer: Optional[str], the response matching ``prompt``, with
          newlines collapsed and surrounding whitespace stripped
        - filename: str, full path to the audio file
    """
    def __init__(
        self,
        json_path: str,
        timit_root: str,
        sample_rate: int = 16000,
        normalize_audio: bool = True
    ):
        super().__init__()
        # Load the JSON data (a list of per-sample dicts).
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.timit_root = timit_root
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Get sample data and resolve the audio file path.
        sample = self.data[idx]
        audio_path = os.path.join(self.timit_root, sample['audio_path'])

        # Load audio and resample to the target rate if needed.
        audio, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # BUGFIX: honor the normalize_audio flag; previously normalization
        # was applied unconditionally and the constructor flag was ignored.
        if self.normalize_audio:
            mean = torch.mean(audio)
            std = torch.std(audio)
            audio = (audio - mean) / (std + 1e-8)  # epsilon avoids div-by-zero on silence

        # Crop (or pad) to exactly 3 seconds at the target sample rate.
        num_samples = audio.shape[1]
        num_samples_3s = 3 * self.sample_rate  # Samples for 3 seconds
        if num_samples >= num_samples_3s:
            # Random 3-second crop anywhere within the clip.
            start_sample = random.randint(0, num_samples - num_samples_3s)
            end_sample = start_sample + num_samples_3s
            audio = audio[:, start_sample:end_sample]
        else:
            # If audio is shorter than 3 seconds, right-pad with zeros.
            pad_size = num_samples_3s - num_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        # Pick one aligned (prompt, response) pair at random, if available.
        prompts = sample.get('prompts', [])
        answers = sample.get('responses', [])
        if prompts and answers and len(prompts) == len(answers):
            rand_idx = random.randint(0, len(prompts) - 1)
            prompt = prompts[rand_idx]
            answer = answers[rand_idx].replace("\n", " ").strip()  # Clean response
        else:
            # Missing or misaligned prompt/response lists: return None for both.
            prompt = None
            answer = None

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'prompt': prompt,
            'answer': answer,
            'filename': audio_path,
        }