Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from .builder import DATASETS, build_dataset | |
| class AdversarialDataset(Dataset): | |
| """Mix Dataset for the adversarial training in 3D human mesh estimation | |
| task. | |
| The dataset combines data from two datasets and | |
| return a dict containing data from two datasets. | |
| Args: | |
| train_dataset (:obj:`Dataset`): Dataset for 3D human mesh estimation. | |
| adv_dataset (:obj:`Dataset`): Dataset for adversarial learning. | |
| """ | |
| def __init__(self, train_dataset: Dataset, adv_dataset: Dataset): | |
| super().__init__() | |
| self.train_dataset = build_dataset(train_dataset) | |
| self.adv_dataset = build_dataset(adv_dataset) | |
| self.num_train_data = len(self.train_dataset) | |
| self.num_adv_data = len(self.adv_dataset) | |
| def __len__(self): | |
| """Get the size of the dataset.""" | |
| return self.num_train_data | |
| def __getitem__(self, idx: int): | |
| """Given index, get the data from train dataset and randomly sample an | |
| item from adversarial dataset. | |
| Return a dict containing data from train and adversarial dataset. | |
| """ | |
| data = self.train_dataset[idx] | |
| adv_idx = np.random.randint(low=0, high=self.num_adv_data, dtype=int) | |
| adv_data = self.adv_dataset[adv_idx] | |
| for k, v in adv_data.items(): | |
| data['adv_' + k] = v | |
| return data | |