import numpy as np
from torch.utils.data import Dataset
import torch

class EmbeddingFile(Dataset):
    """
    Dataset of precomputed embeddings loaded from a single ``.npz`` file.

    Modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    The file must contain the arrays "feats" and "labels", plus "imgs" when
    ``loading_imgs`` is True; only the first ``num_limit`` samples are kept.

    Attributes:
        file: The path (or file object) passed to ``np.load``.
        feats (np.ndarray): Embeddings, one row per sample.
        labels (np.ndarray): Integer class index of each sample.
        imgs (np.ndarray or list): Image of each sample, or a list of zeros
            used as placeholders when ``loading_imgs`` is False.
    """

    def __init__(self, file, loading_imgs, num_limit=60000):
        super().__init__()
        self.file = file
        # The context manager closes the .npz file handle once the arrays
        # have been read into memory.
        with np.load(file) as loaded:
            self.feats = loaded["feats"][:num_limit]
            self.labels = loaded["labels"][:num_limit]
            if loading_imgs:
                self.imgs = loaded["imgs"][:num_limit]
            else:
                # Placeholder so __getitem__ can always return a 3-tuple.
                self.imgs = [0] * len(self.labels)

    def dim(self):
        return self.feats.shape[1]

    def num_classes(self):
        # Assumes labels are contiguous integer class indices starting at 0.
        return int(self.labels.max()) + 1

    def __getitem__(self, index):
        # Returns (img, feat, label); img is 0 when images were not loaded.
        return self.imgs[index], self.feats[index], self.labels[index]

    def __len__(self):
        return len(self.labels)
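

# Hypothetical helper, not part of the original module: a minimal sketch of
# how a file readable by EmbeddingFile might be produced. It assumes the
# layout the loaders in this module read: an .npz archive with a 2-D "feats"
# array of shape (num_samples, dim), a 1-D integer "labels" array and,
# optionally, an "imgs" array holding one image per sample.
import numpy as np


def save_embedding_file(path, feats, labels, imgs=None):
    """Write feats/labels (and optionally imgs) to an .npz archive."""
    feats = np.asarray(feats)
    labels = np.asarray(labels)
    if feats.ndim != 2 or len(feats) != len(labels):
        raise ValueError("feats must be 2-D with one row per label")
    arrays = {"feats": feats, "labels": labels}
    if imgs is not None:
        imgs = np.asarray(imgs)
        if len(imgs) != len(labels):
            raise ValueError("imgs must have one entry per label")
        arrays["imgs"] = imgs
    # When given a string path, np.savez appends ".npz" if it is missing.
    np.savez(path, **arrays)


# Example (hypothetical file name):
#   save_embedding_file("train_embeddings.npz", feats, labels)
#   ds = EmbeddingFile("train_embeddings.npz", loading_imgs=False)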


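class EmbeddingAndImage(Dataset):
    """
    Pairs precomputed embeddings from an ``.npz`` file ("feats" and "labels")
    with the images of an existing dataset. Each item of ``dataset`` must be a
    dict with an "img" entry; the i-th embedding is assumed to belong to the
    i-th item. All images are copied into a single preallocated tensor.
    """

    def __init__(self, file, dataset):
        super().__init__()
        self.file = file
        with np.load(file) as loaded:
            self.feats = loaded["feats"]
            self.labels = loaded["labels"]
        num_imgs = len(dataset)
        # Only the "img" key of each (dict) item is used.
        img_shape = dataset[0]["img"].shape
        self.imgs = torch.empty((num_imgs, *img_shape))
        # Index explicitly instead of iterating: a map-style Dataset need not
        # define __iter__ or stop after len(dataset) items.
        for i in range(num_imgs):
            self.imgs[i] = dataset[i]["img"]
        # NOTE: this per-sample copy loop is too slow for large datasets.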

    def dim(self):
        return self.feats.shape[1]

    def num_classes(self):
        # Assumes labels are contiguous integer class indices starting at 0.
        return int(self.labels.max()) + 1

    def __getitem__(self, index):
        # Returns (feat, label, img), a different order from EmbeddingFile.
        return self.feats[index], self.labels[index], self.imgs[index]

    def __len__(self):
        return len(self.labels)
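

# Hypothetical helper, not part of the original module: a sketch of one way
# to address the "too slow" note in EmbeddingAndImage.__init__. It assumes,
# as that class does, that every item of `dataset` is a dict whose "img"
# entry is a tensor (or array) of a fixed shape. A torch DataLoader gathers
# items in batches, the collate function below keeps only "img" and stacks
# it, and num_workers > 0 can load items in parallel. Whether this is
# actually faster depends on where the dataset spends its time.
import torch
from torch.utils.data import DataLoader


def _stack_imgs(items):
    # Keep only the "img" entry of each dict item and stack into one batch.
    # Defined at module level so it can be pickled when num_workers > 0.
    return torch.stack([torch.as_tensor(d["img"]) for d in items])


def collect_images(dataset, batch_size=256, num_workers=0):
    """Return a tensor of shape (len(dataset), *img_shape) of all "img" entries."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, collate_fn=_stack_imgs)
    img_shape = torch.as_tensor(dataset[0]["img"]).shape
    imgs = torch.empty((len(dataset), *img_shape))
    start = 0
    for batch in loader:
        # Copy a whole batch at once; shuffle=False keeps dataset order.
        imgs[start:start + len(batch)] = batch
        start += len(batch)
    return imgs


# Example: self.imgs = collect_images(dataset, batch_size=256, num_workers=4)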