| | import torch |
| | from torch.utils.data import Dataset |
| | import json |
| | import torchaudio |
| | import os |
| | from typing import Optional, Dict, Any, List, Tuple |
| | import pandas as pd |
| | import warnings |
| | import random |
| | from pathlib import Path |
| | from collections import defaultdict |
| |
|
| |
|
| |
|
| |
|
class TEARSDataset(Dataset):
    """
    TEARS dataset that loads audio clips with speaker metadata and
    prompt/response pairs from a JSON manifest.

    Args:
        json_path (str): Path to the JSON manifest. Each entry is expected to
            contain 'audio_path', a 'speaker' dict with an 'id' key, and
            optionally parallel 'prompts'/'responses' lists.
        tears_root (str): Root directory that audio paths are relative to.
        sample_rate (int, optional): Target sample rate for audio. Defaults to 16000.
        duration (float, optional): Target clip duration in seconds. Defaults to 3.0.
        normalize_audio (bool, optional): Whether to mean/std-normalize audio.
            Defaults to True.
        augment (bool, optional): Whether to apply random waveform
            augmentations. Defaults to True.

    Each item is a dict containing:
        - 'audio_tensor': torch.Tensor of shape (channels, target_samples)
        - 'sid': str, speaker identifier
        - 'metadata': dict, the sample's 'speaker' entry
        - 'prompt': str or None, randomly selected prompt
        - 'answer': str or None, the response matching the selected prompt
        - 'filename': str, resolved path to the audio file
    """

    def __init__(
        self,
        json_path: str,
        tears_root: str,
        sample_rate: int = 16000,
        duration: float = 3.0,
        normalize_audio: bool = True,
        augment: bool = True
    ):
        super().__init__()

        with open(json_path, 'r') as f:
            self.data = json.load(f)

        self.tears_root = Path(tears_root)
        self.sample_rate = sample_rate
        self.duration = duration
        self.normalize_audio = normalize_audio
        self.target_samples = int(duration * sample_rate)
        self.augment = augment

    def __len__(self) -> int:
        return len(self.data)

    def augment_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Apply a random subset (1..all, shuffled order) of waveform augmentations.

        Fixes vs. the original: the advertised 'spec_aug' choice had no
        handler (silent no-op) while the implemented 'reverb' branch was
        unreachable; a duplicate dead 'pitch_shift' branch and an undefined
        `T` reference existed; and the sample rate was hard-coded to 16000
        instead of using the `sample_rate` argument.
        """
        augmentation_choices = ['time_stretch', 'pitch_shift', 'add_noise', 'reverb']
        random.shuffle(augmentation_choices)

        num_augs = random.randint(1, len(augmentation_choices))
        for aug in augmentation_choices[:num_augs]:
            if aug == 'time_stretch':
                # 'speed' changes tempo+pitch; 'rate' resamples the result
                # back to the target sample rate (output length changes).
                rate = random.uniform(0.8, 1.25)
                effects = [['speed', str(rate)], ['rate', str(sample_rate)]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effects=effects
                )

            elif aug == 'pitch_shift':
                # Shift by whole semitones; sox 'pitch' takes cents
                # (100 per semitone). 0 is excluded so the augmentation
                # always has an effect.
                semitones = random.choice([-2, -1, 1, 2])
                effects = [['pitch', str(semitones * 100)], ['rate', str(sample_rate)]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effects=effects
                )

            elif aug == 'add_noise':
                # Low-amplitude additive white noise.
                noise = torch.randn_like(waveform) * random.uniform(0.001, 0.015)
                waveform = waveform + noise

            elif aug == 'reverb':
                # '-w' selects wet-only output; argument is reverberance %.
                effects = [['reverb', '-w', str(random.randint(10, 50))]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effects=effects
                )

        return waveform

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load, (optionally) augment, normalize, and crop/pad one sample.

        On any audio loading/processing error the sample degrades to a
        silent clip (zeros) instead of crashing the training loop.
        """
        sample = self.data[idx]
        audio_path = str(self.tears_root / sample['audio_path'])

        try:
            audio, sr = torchaudio.load(audio_path)

            # Resample to the target rate if the file rate differs.
            if sr != self.sample_rate:
                audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

            if self.augment:
                audio = self.augment_audio(audio, self.sample_rate)

            # Per-clip mean/std normalization; epsilon guards silent clips.
            if self.normalize_audio:
                mean = torch.mean(audio)
                std = torch.std(audio)
                audio = (audio - mean) / (std + 1e-8)

            # Random crop to the target length, or right-pad with zeros.
            num_samples = audio.shape[1]
            if num_samples >= self.target_samples:
                start_sample = random.randint(0, num_samples - self.target_samples)
                audio = audio[:, start_sample:start_sample + self.target_samples]
            else:
                pad_size = self.target_samples - num_samples
                audio = torch.nn.functional.pad(audio, (0, pad_size))

        except Exception as e:
            # Best-effort fallback: warn and emit a silent mono clip.
            warnings.warn(f"Error loading audio file {audio_path}: {str(e)}")
            audio = torch.zeros(1, self.target_samples)

        # Pick one prompt/response pair at random when both lists exist
        # and are aligned; otherwise return None for both.
        prompts = sample.get('prompts', [])
        responses = sample.get('responses', [])

        if prompts and responses and len(prompts) == len(responses):
            rand_idx = random.randint(0, len(prompts) - 1)
            prompt = prompts[rand_idx]
            response = responses[rand_idx].replace("\n", " ").strip()
        else:
            prompt = None
            response = None

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'metadata': sample['speaker'],
            'prompt': prompt,
            'answer': response,
            'filename': str(audio_path)
        }

    @staticmethod
    def redistribute_speakers(
        json_paths: Dict[str, str],
        split_ratios: Dict[str, float],
        seed: int = 42
    ) -> Dict[str, List[Dict]]:
        """
        Redistribute speakers across splits according to given ratios.

        All samples belonging to one speaker land in the same split,
        preventing speaker leakage across splits.

        Args:
            json_paths: Dict mapping split names to JSON file paths.
            split_ratios: Dict mapping split names to desired ratios
                (should sum to 1).
            seed: Random seed for reproducibility.

        Returns:
            Dict mapping split names to lists of samples.
        """
        # Local RNG: same shuffle for a given seed as random.seed(seed)
        # would give, without clobbering global random state.
        rng = random.Random(seed)

        # Group every sample by speaker id across all input splits.
        speaker_samples = defaultdict(list)
        for path in json_paths.values():
            with open(path, 'r') as f:
                data = json.load(f)
            for sample in data:
                speaker_samples[sample['speaker']['id']].append(sample)

        all_speakers = list(speaker_samples.keys())
        rng.shuffle(all_speakers)

        # Integer speaker counts per split (floor of ratio * total).
        total_speakers = len(all_speakers)
        split_speakers = {
            split: int(ratio * total_speakers)
            for split, ratio in split_ratios.items()
        }

        # Flooring can leave speakers unassigned; assign them to the
        # first split (insertion order of split_ratios).
        remainder = total_speakers - sum(split_speakers.values())
        if remainder > 0:
            first_split = next(iter(split_speakers))
            split_speakers[first_split] += remainder

        # Slice the shuffled speaker list into contiguous per-split chunks.
        new_splits = defaultdict(list)
        current_idx = 0
        for split, num_speakers in split_speakers.items():
            for speaker_id in all_speakers[current_idx:current_idx + num_speakers]:
                new_splits[split].extend(speaker_samples[speaker_id])
            current_idx += num_speakers

        return new_splits

    @staticmethod
    def save_splits(splits: Dict[str, List[Dict]], output_dir: str):
        """Save redistributed splits to JSON files, one file per split.

        Args:
            splits: Mapping of split name -> list of samples, as returned
                by `redistribute_speakers`.
            output_dir: Directory to write into; created if missing.
        """
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        for split_name, samples in splits.items():
            output_path = out_dir / f"tears_dataset_{split_name}_with_responses.json"
            with open(output_path, 'w') as f:
                json.dump(samples, f, indent=2)
| |
|
| |
|