SAE / attacks /CleanSheet /poison_dataset.py
Ttius's picture
Upload 192 files
998bb30 verified
from torch.utils.data import Dataset
import torch
class PoisonDataset(Dataset):
    """View over a subset of ``dataset`` whose labels are all replaced by ``target``.

    The wrapped dataset must return a 3-element sequence per item; the second
    element is passed through as the sample, while the first and third are
    discarded (presumably an index/id and the original label -- confirm
    against the wrapped dataset).

    Args:
        dataset: Indexable object returning 3-element items.
        indices: Iterable of positions into ``dataset``; each entry is
            converted with ``int()`` so tensor or NumPy scalars are accepted.
        target: Label returned for every item, passed through unchanged.
    """

    def __init__(self, dataset, indices, target):
        super().__init__()
        self.dataset = dataset
        # Normalise to plain Python ints once so lookups do not depend on
        # the scalar type the caller used for the indices.
        self.indices = [int(i) for i in indices]
        self.target = target

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.indices)

    def __getitem__(self, item):
        """Return ``(x, target)`` for the ``item``-th selected sample.

        Raises:
            IndexError: If ``item`` is outside the range of ``indices``.
        """
        # Only the middle element is kept; the original label is dropped
        # in favour of the fixed target.
        _, x, _ = self.dataset[self.indices[item]]
        return x, self.target