Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| import torchvision.transforms as transforms | |
class SpritesDataset(Dataset):
    """Map-style dataset over sprite images and labels stored as ``.npy`` files.

    Both arrays are loaded eagerly into memory with ``np.load`` using
    ``allow_pickle=False``, so object arrays in either file are rejected.

    Args:
        images_path: Path to the ``.npy`` file holding the images. Indexing
            uses the first axis, and ``len()`` is that axis's length.
        labels_path: Path to the ``.npy`` file holding the labels, indexed by
            the same ``idx`` as the images. It is loaded even when
            ``null_context`` is set.
        transform: Callable applied to each image array in ``__getitem__``
            (e.g. a torchvision ``Compose``). ``None`` (the default) returns
            the raw image array unchanged.
        null_context: If true, every sample's label is the scalar ``0``
            instead of the stored label. Defaults to ``False``.
    """

    def __init__(self, images_path, labels_path, transform=None,
                 null_context=False):
        self.images = np.load(images_path, allow_pickle=False)
        self.labels = np.load(labels_path, allow_pickle=False)
        # Shapes are recorded once at load time and reported by __getshape__.
        self.images_shape = self.images.shape
        self.labels_shape = self.labels.shape
        self.transform = transform
        self.null_context = null_context

    def __len__(self):
        """Return the number of images (length of the images' first axis)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``.

        ``image`` is ``transform(images[idx])``, or the raw array when no
        transform was given. ``label`` is an ``int64`` tensor. It is the
        scalar ``0`` when ``null_context`` is set. Otherwise it is
        ``labels[idx]`` cast to ``int64``.
        """
        image = self.images[idx]
        # Only apply a transform when one was supplied (torchvision convention);
        # calling None unconditionally used to raise TypeError.
        if self.transform is not None:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0, dtype=torch.int64)
        else:
            # torch.tensor copies the data, so the returned label never
            # aliases self.labels. .to() performs the int64 cast.
            label = torch.tensor(self.labels[idx]).to(torch.int64)
        return image, label

    def __getshape__(self):
        """Return ``(images_shape, labels_shape)`` as recorded at load time.

        Despite its dunder-style name, this is an ordinary method, not a
        Python protocol hook. It must be called explicitly.
        """
        return self.images_shape, self.labels_shape
# Preprocessing pipeline for sprite images.
sprites_transform = transforms.Compose(
    [
        # ndarray (H x W x C) -> float tensor (C x H x W). For uint8 input,
        # ToTensor also rescales to [0, 1]. Assumes the sprite arrays are
        # uint8 HWC; TODO confirm against the .npy files.
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1].
        # Three entries, so three-channel images are expected.
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
    ]
)