Spaces:
Running on Zero
Running on Zero
File size: 2,102 Bytes
52d0a0e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import numpy as np
from torch.utils.data import Dataset
import torch
class EmbeddingFile(Dataset):
    """Dataset over precomputed embeddings stored in a NumPy ``.npz`` archive.

    The archive is read once, at construction time. It must contain the
    entries ``"feats"`` and ``"labels"``, and additionally ``"imgs"`` when
    ``loading_imgs`` is true. Only the first ``num_limit`` rows of every
    entry are kept.

    Attributes:
        file: The ``file`` argument the archive was loaded from.
        feats: Feature rows; ``dim()`` reports ``feats.shape[1]``.
        labels: One label per sample; ``num_classes()`` reports
            ``labels.max() + 1`` (assumes non-negative integer class
            indices -- TODO confirm against the code that writes the file).
        imgs: The ``"imgs"`` entry of the archive, or a list of ``0``
            placeholders (one per label) when images are not loaded.
    """

    def __init__(self, file, loading_imgs, num_limit=60000):
        """Load the arrays from ``file``.

        Args:
            file: Path or file object accepted by ``numpy.load``.
            loading_imgs: When true, also load the ``"imgs"`` entry;
                otherwise every sample gets the placeholder ``0`` as image.
            num_limit: Upper bound on the number of leading rows to keep
                (used as a slice end, so ``None`` keeps everything).
        """
        # Zero-argument super() starts the lookup after *this* class; the
        # previous ``super(Dataset, self)`` skipped Dataset itself.
        super().__init__()
        self.file = file
        # Close the archive deterministically, even if a key is missing.
        with np.load(file) as loaded:
            self.feats = loaded["feats"][:num_limit]
            self.labels = loaded["labels"][:num_limit]
            if loading_imgs:
                self.imgs = loaded["imgs"][:num_limit]
            else:
                # Placeholders keep __getitem__'s 3-tuple shape stable.
                self.imgs = [0] * len(self.labels)

    def dim(self):
        """Return the size of the second axis of ``feats``."""
        return self.feats.shape[1]

    def num_classes(self):
        """Return ``labels.max() + 1``."""
        return self.labels.max() + 1

    def __getitem__(self, index):
        """Return the ``(img, feat, label)`` triple stored at ``index``."""
        return self.imgs[index], self.feats[index], self.labels[index]

    def __len__(self):
        """Return the number of labels kept after applying ``num_limit``."""
        return len(self.labels)
class EmbeddingAndImage(Dataset):
    """Dataset pairing precomputed embeddings with images from another dataset.

    Features and labels come from the ``"feats"`` and ``"labels"`` entries of
    a NumPy ``.npz`` archive. Images are copied, at construction time, from
    the ``"img"`` entry of every sample of ``dataset`` into one tensor.

    Attributes:
        file: The ``file`` argument the archive was loaded from.
        feats: Feature rows; ``dim()`` reports ``feats.shape[1]``.
        labels: One label per sample; ``num_classes()`` reports
            ``labels.max() + 1`` (assumes non-negative integer class
            indices -- TODO confirm against the code that writes the file).
        imgs: Tensor of shape ``(len(dataset), *dataset[0]["img"].shape)``.
    """

    def __init__(self, file, dataset):
        """Load the embeddings and materialize all images.

        Args:
            file: Path or file object accepted by ``numpy.load``.
            dataset: Sized, indexable, iterable collection whose samples are
                mappings with an ``"img"`` entry; every image is presumed to
                share the shape of the first one. Must be non-empty, since
                ``dataset[0]`` is read to determine that shape.
        """
        # Zero-argument super() starts the lookup after *this* class; the
        # previous ``super(Dataset, self)`` skipped Dataset itself.
        super().__init__()
        self.file = file
        # Close the archive deterministically, even if a key is missing.
        with np.load(file) as loaded:
            self.feats = loaded["feats"]
            self.labels = loaded["labels"]

        num_imgs = len(dataset)
        img_shape = dataset[0]["img"].shape
        # Allocated with torch's default dtype; each image is converted to it
        # when assigned below.
        self.imgs = torch.empty((num_imgs, *img_shape))
        # Samples are dicts: only their "img" entry is kept.
        # NOTE: flagged by the original author as too slow -- every sample is
        # fetched and copied one at a time.
        for i, sample in enumerate(dataset):
            self.imgs[i] = sample["img"]

    def dim(self):
        """Return the size of the second axis of ``feats``."""
        return self.feats.shape[1]

    def num_classes(self):
        """Return ``labels.max() + 1``."""
        return self.labels.max() + 1

    def __getitem__(self, index):
        """Return the ``(feat, label, img)`` triple stored at ``index``."""
        return self.feats[index], self.labels[index], self.imgs[index]

    def __len__(self):
        """Return the number of labels in the archive."""
        return len(self.labels)
|