| from torch.utils.data import Dataset | |
| import torch | |
class PoisonDataset(Dataset):
    """Subset view of ``dataset`` in which every sample is paired with ``target``.

    The wrapped dataset must yield 3-element samples. Only the middle element
    is kept; the first and last elements are discarded, and the fixed
    ``target`` is returned in place of the last one.
    NOTE(review): the last element is presumably the original label and the
    first an id/index -- confirm against the wrapped dataset.
    """

    def __init__(self, dataset, indices, target):
        """Store the wrapped dataset, the selected positions and the fixed target.

        Args:
            dataset: Indexable object whose items unpack into three values.
            indices: Iterable of positions into ``dataset``; each entry is
                converted with ``int()``.
            target: Value returned as the second element of every sample.
                It is stored and returned as-is (no tensor conversion).
        """
        self.dataset = dataset
        # Coerce once up front so that any index type accepted by int()
        # ends up as a plain Python int in the lookup list.
        self.indices = [int(i) for i in indices]
        self.target = target

    def __len__(self) -> int:
        """Return the number of selected samples."""
        return len(self.indices)

    def __getitem__(self, item):
        """Return ``(x, target)`` for the ``item``-th selected sample.

        ``item`` indexes into ``self.indices``, not into the wrapped dataset.
        """
        # The first and last elements of the underlying sample are not used;
        # the last one is replaced by the fixed target.
        _, x, _ = self.dataset[self.indices[item]]
        return x, self.target