| import json |
| import urllib.request |
|
|
| import torchvision.transforms as T |
| from cods.classif.data import ClassificationDataset |
| from PIL import Image |
|
|
|
|
class DatasetWrapper(ClassificationDataset):
    """Adapt an indexable dataset of ``{"image", "label"}`` items to ``ClassificationDataset``.

    Each item's image is converted to RGB, passed through ``transforms``
    (an ImageNet-style resize / center-crop / normalize pipeline by default)
    and returned as ``(idx, image, label)``. The ImageNet class-index JSON is
    downloaded at most once per class and shared across instances.
    """

    # JSON object mapping "<int idx>" -> [wnid, class name]; element 0 is stored
    # in ``wdnids`` and element 1 in ``idx_to_cls``.
    CLASS_INDEX_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"

    # Seconds to wait for the class-index download before giving up.
    CLASS_INDEX_TIMEOUT = 60

    # Parsed class-index JSON, fetched lazily by ``_load_class_index`` so that
    # constructing several wrappers does not re-download it. Never mutated.
    _class_index_json = None

    def __init__(self, dataset, transforms=None, **kwargs):
        """Wrap ``dataset`` and load the ImageNet class-index mapping.

        Args:
            dataset: Sized, indexable collection whose items are mappings with
                an ``"image"`` key (a PIL-style image exposing ``.mode`` and
                ``.convert``) and a ``"label"`` key.
            transforms: Callable applied to each RGB image. When ``None``, uses
                Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize with
                the ImageNet mean/std below.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            urllib.error.URLError: Propagated if the class index cannot be
                downloaded (a timeout may surface as ``TimeoutError``).
        """
        # NOTE(review): ClassificationDataset.__init__ is not called here, as in
        # the original; confirm the base class does not require it.
        self.dataset = dataset
        # NOTE(review): root is not read anywhere in this class; kept for
        # callers or the base class that may expect it — confirm.
        self.root = "./data"
        # Left empty here; TODO confirm whether the base class relies on it.
        self.image_ids = []

        if transforms is None:
            transforms = T.Compose(
                [
                    T.Resize(256),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    T.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                    ),
                ],
            )
        self.transforms = transforms

        class_index = self._load_class_index()
        # Fresh dicts per instance, so the shared cached JSON is never mutated.
        self.wdnids = {int(k): v[0] for k, v in class_index.items()}
        self.idx_to_cls = {int(k): v[1] for k, v in class_index.items()}

    @classmethod
    def _load_class_index(cls):
        """Return the parsed class-index JSON, downloading it at most once."""
        if cls._class_index_json is None:
            # The context manager closes the HTTP response promptly; the timeout
            # keeps a stalled connection from hanging construction forever.
            with urllib.request.urlopen(
                cls.CLASS_INDEX_URL,
                timeout=cls.CLASS_INDEX_TIMEOUT,
            ) as response:
                cls._class_index_json = json.loads(response.read())
        return cls._class_index_json

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(idx, image, label)`` for the item at ``idx``.

        The image is converted to RGB when its mode is anything else
        (e.g. grayscale or RGBA), then passed through ``self.transforms``
        unless that is ``None``.
        """
        item = self.dataset[idx]
        img, label = item["image"], item["label"]

        if img.mode != "RGB":
            img = img.convert("RGB")

        # Compare against None rather than truthiness: a transform container
        # defining __len__ could be falsy and would otherwise be skipped.
        if self.transforms is not None:
            img = self.transforms(img)
        return idx, img, label
|
|