| import os |
| import copy |
| import numpy as np |
| from PIL import Image |
| from os.path import join |
| from itertools import chain |
| from collections import defaultdict |
|
|
| import torch |
| import torch.utils.data as data |
| from torchvision import transforms |
|
|
# Default dataset root. The loader below expects it to contain a ``labels``
# directory (split listings) and an ``images`` directory (the image files).
DATA_ROOTS = 'data/DTD'


class DTD(data.Dataset):
    """Image-classification dataset read from split listing files.

    NOTE(review): "DTD" presumably stands for the Describable Textures
    Dataset; the code itself only relies on the directory layout below.

    Expected layout under ``root``::

        labels/train1.txt, labels/val1.txt, labels/test1.txt
        images/<category>/<file>

    Each non-blank line of a listing file is an image path relative to
    ``images/`` whose first ``/``-separated component is the category name.

    Args:
        root: Dataset root directory.
        train: If true, load ``train1.txt`` followed by ``val1.txt``;
            otherwise load ``test1.txt``.
        image_transforms: Optional callable applied to each RGB PIL image
            before it is returned from ``__getitem__``.

    Attributes set by ``__init__``: ``root``, ``train``,
    ``image_transforms``, ``paths`` (list of image file paths) and
    ``labels`` (list of integer class indices, parallel to ``paths``).
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        self.paths, self.labels = self.load_images()

    def _read_split(self, filename):
        """Return the non-blank entries of ``<root>/labels/<filename>``."""
        split_path = os.path.join(self.root, 'labels', filename)
        with open(split_path, 'r') as f:
            entries = [line.rstrip('\n') for line in f]
        # A blank line would otherwise become an empty category name and a
        # path that points at the images directory itself.
        return [entry for entry in entries if entry.strip()]

    def load_images(self):
        """Build the image path list and the matching label list.

        Labels are indices into the alphabetically sorted list of category
        names that occur in the loaded split. Label ids of the train and
        test datasets therefore only agree when both splits contain the
        same set of categories.

        Returns:
            A ``(paths, labels)`` tuple of equally long lists, in listing
            order (train entries first, then val entries, when ``train``
            is true).

        Raises:
            OSError: If a required listing file cannot be opened.
        """
        if self.train:
            split_info = (self._read_split('train1.txt')
                          + self._read_split('val1.txt'))
        else:
            split_info = self._read_split('test1.txt')

        # The category is the first path component of every entry.
        entry_categories = [entry.split('/')[0] for entry in split_info]

        # Build the name -> index table once instead of calling
        # ``list.index`` for every image.
        categories = sorted(set(entry_categories))
        label_of = {name: idx for idx, name in enumerate(categories)}

        all_paths = [os.path.join(self.root, 'images', entry)
                     for entry in split_info]
        all_labels = [label_of[name] for name in entry_categories]
        return all_paths, all_labels

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is opened from disk, converted to RGB and, when
        ``image_transforms`` was given, passed through it.
        """
        path = self.paths[index]
        label = self.labels[index]
        # ``convert`` returns a new, fully loaded image, so the file handle
        # can be closed right away instead of lingering until GC.
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        if self.image_transforms is not None:
            image = self.image_transforms(image)
        return image, label