# cods-live / dataset.py
# Author: Léo Andéol
# Commit: 9c3ff92 (unverified)
import json
import urllib.request
import torchvision.transforms as T
from cods.classif.data import ClassificationDataset
from PIL import Image
class DatasetWrapper(ClassificationDataset):
    """Adapt an indexable image dataset to the ``ClassificationDataset`` API.

    Each item of the wrapped ``dataset`` is read as ``item["image"]`` and
    ``item["label"]``; the image must support ``.mode`` and ``.convert``
    (presumably a PIL image, e.g. from a Hugging Face dataset — confirm
    against callers). Class names are taken from the ImageNet class-index
    JSON downloaded from ``CLASS_INDEX_URL`` at construction time.
    """

    # Location of the ImageNet class-index JSON: a mapping from a
    # stringified class index to a two-element list [id, class_name].
    CLASS_INDEX_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
    # Seconds to wait for the class-index download. Without a timeout a
    # stalled connection would block the constructor forever.
    CLASS_INDEX_TIMEOUT = 30

    def __init__(self, dataset, transforms=None, **kwargs):
        """Wrap *dataset* and load the ImageNet class index.

        Args:
            dataset: Indexable, sized collection whose items are mappings
                with ``"image"`` and ``"label"`` keys.
            transforms: Callable applied to each RGB image in
                ``__getitem__``. When ``None``, a default resize /
                center-crop / to-tensor / normalize pipeline is used.
            **kwargs: Accepted for interface compatibility; currently
                ignored.

        Raises:
            urllib.error.URLError: If the class index cannot be downloaded
                (including on timeout).
            json.JSONDecodeError: If the downloaded payload is not JSON.
        """
        self.dataset = dataset
        # NOTE(review): root and image_ids are not used in this class;
        # presumably kept because ClassificationDataset code reads them.
        self.root = "./data"
        self.image_ids = []
        if transforms is None:
            transforms = self._default_transforms()
        self.transforms = transforms

        class_index = self._load_class_index(
            self.CLASS_INDEX_URL,
            self.CLASS_INDEX_TIMEOUT,
        )
        # Element [0] of each entry; presumably the WordNet id (wnid).
        self.wdnids = {int(k): v[0] for k, v in class_index.items()}
        # super().__init__ is intentionally not called: the parent
        # constructor takes an on-disk ``path``, which this in-memory
        # wrapper does not have. The attributes it would set are assigned
        # directly instead.
        self.idx_to_cls = {int(k): v[1] for k, v in class_index.items()}

    @staticmethod
    def _default_transforms():
        """Build the default 224x224 normalized-tensor preprocessing."""
        return T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                # Per-channel mean/std commonly used with ImageNet-
                # pretrained torchvision models.
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ],
        )

    @staticmethod
    def _load_class_index(url, timeout):
        """Download *url* and return its parsed JSON content."""
        # The context manager closes the HTTP response deterministically
        # instead of leaving the socket open until garbage collection.
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.load(response)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(idx, image, label)`` for item *idx*.

        The image is converted to RGB when its mode differs, so the
        transforms always receive three-channel input, and is then passed
        through ``self.transforms`` when that is truthy.
        """
        item = self.dataset[idx]
        img, label = item["image"], item["label"]
        if img.mode != "RGB":
            img = img.convert("RGB")
        if self.transforms:
            img = self.transforms(img)
        return idx, img, label