| import torch | |
| import random | |
| from torch.utils.data import Dataset, DataLoader | |
class CombinedDataset(Dataset):
    """Mix two datasets by drawing each item from one of them at random.

    On every ``__getitem__`` call a uniform draw in ``[0, 1)`` decides which
    source serves the item: ``dataset1`` when the draw is below
    ``switch_prob``, otherwise ``dataset2``. The index is wrapped with ``%``
    into the chosen dataset's range, so the shorter dataset is cycled.
    Because the choice is re-drawn on every access, the same ``idx`` may
    return different samples across calls.

    Originally used to mix TIMIT (``dataset1``) and EARS (``dataset2``), but
    any two sized, integer-indexable datasets work.

    Args:
        dataset1 (Dataset): The first dataset (e.g., TIMITDataset).
        dataset2 (Dataset): The second dataset (e.g., EARS).
        switch_prob (float): Probability of picking from dataset1
            (default: 0.5). Values <= 0 always pick dataset2; values >= 1
            always pick dataset1.
        rng (random.Random, optional): Object with a ``random()`` method used
            for the source choice. ``None`` (default) uses the global
            ``random`` module. Pass a seeded ``random.Random`` for
            reproducible mixing. NOTE: a custom rng is copied with identical
            state into every DataLoader worker, so workers draw the same
            choice sequence unless it is reseeded in a ``worker_init_fn``.
    """

    def __init__(self, dataset1, dataset2, switch_prob=0.5, rng=None):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.len1 = len(dataset1)
        self.len2 = len(dataset2)
        self.switch_prob = switch_prob  # Probability of picking from dataset1
        # Keep None rather than storing the `random` module itself: module
        # objects are not picklable, which would break spawn-based workers.
        self.rng = rng

    def __len__(self):
        """Return the length of the longer of the two datasets."""
        return max(self.len1, self.len2)

    def __getitem__(self, idx):
        """Return one sample from a randomly chosen source dataset.

        Args:
            idx (int): Sample index; wrapped into the chosen dataset's range.

        Returns:
            Whatever the chosen dataset returns for the wrapped index.

        Raises:
            IndexError: If the chosen dataset is empty (a bare ``idx % 0``
                would otherwise raise an opaque ZeroDivisionError).
        """
        draw = random.random() if self.rng is None else self.rng.random()
        if draw < self.switch_prob:
            dataset, length, name = self.dataset1, self.len1, "dataset1"
        else:
            dataset, length, name = self.dataset2, self.len2, "dataset2"
        if length == 0:
            raise IndexError(
                f"CombinedDataset: {name} is empty, cannot sample index {idx}"
            )
        # Wrap idx so the shorter dataset is cycled over the longer length.
        return dataset[idx % length]