File size: 1,105 Bytes
b6b6742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class SpritesDataset(Dataset):
    """Map-style dataset that serves (image, label) pairs from two ``.npy`` files.

    Both arrays are loaded fully into memory at construction time. Each item
    is the transformed image at the requested index paired with an ``int64``
    label tensor.

    Args:
        images_path: Path passed to ``np.load`` for the image array.
        labels_path: Path passed to ``np.load`` for the label array.
        transform: Callable applied to ``images[idx]`` on every lookup.
        null_context: When truthy, the stored labels are ignored and every
            item is paired with a scalar ``0`` label instead.

    Attributes:
        images, labels: The loaded arrays.
        images_shape, labels_shape: Shapes of those arrays, captured once
            right after loading.
    """

    def __init__(self, images_path, labels_path, transform, null_context):
        self.transform = transform
        self.null_context = null_context

        # allow_pickle=False makes np.load reject object arrays, so loading
        # never unpickles data from the files.
        loaded_images = np.load(images_path, allow_pickle=False)
        loaded_labels = np.load(labels_path, allow_pickle=False)

        self.images, self.labels = loaded_images, loaded_labels
        self.images_shape = loaded_images.shape
        self.labels_shape = loaded_labels.shape

    def __len__(self):
        """Return the number of images (length of the first axis)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for position ``idx``.

        The label is the stored entry cast to ``int64``, or a scalar ``0``
        tensor of dtype ``int64`` when ``null_context`` is set.
        """
        sample = self.transform(self.images[idx])

        if not self.null_context:
            target = torch.tensor(self.labels[idx]).to(torch.int64)
            return sample, target

        # Context disabled: every sample gets the same placeholder label.
        return sample, torch.tensor(0, dtype=torch.int64)

    def __getshape__(self):
        """Return ``(images_shape, labels_shape)`` as recorded at load time."""
        return (self.images_shape, self.labels_shape)

# Preprocessing pipeline for sprite images: convert the input to a tensor, then
# normalize each of the three channels as (x - 0.5) / 0.5.
# NOTE(review): assuming uint8 input, ToTensor yields values in [0, 1], so the
# normalized output lands in [-1, 1] -- confirm against the stored image dtype.
sprites_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)