| import os |
| import numpy as np |
| from PIL import Image |
| import torch.utils.data as data |
|
|
# Default dataset root used by ``CUBirds``: the directory that contains
# images.txt, image_class_labels.txt, train_test_split.txt and images/.
DATA_ROOTS = "data/CUBirds"
|
|
class CUBirds(data.Dataset):
    """Image dataset read from a CUBirds-style directory layout.

    The following index files are read directly under ``root``:

    - ``images.txt``
    - ``image_class_labels.txt``
    - ``train_test_split.txt``

    Each line has the form ``<image_id> <value>``. Images themselves are
    loaded from ``root/images/<relative path>``.

    Args:
        root: Dataset root directory. Defaults to ``DATA_ROOTS``.
        train: If True, keep images whose split flag is ``1``; otherwise keep
            images whose split flag is ``0``.
        image_transforms: Optional callable applied to each PIL image in
            ``__getitem__``.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        # Relative image paths and their class ids (kept as the strings read
        # from image_class_labels.txt), filtered to the requested split.
        self.paths, self.labels = self.load_images()

    @staticmethod
    def _read_id_map(path):
        """Parse an ``<image_id> <value>`` index file into an insertion-ordered dict.

        Surrounding whitespace (including ``\\r`` from CRLF files) is stripped,
        blank lines are skipped, and only the first whitespace run separates
        the id from the value, so values may contain spaces. As with ``dict``,
        a repeated id keeps its first position but takes the last value.

        Raises:
            ValueError: if a non-blank line does not contain both an id and a
                value.
        """
        mapping = {}
        with open(path, 'r', encoding='utf-8') as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                parts = line.split(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f'{path}:{lineno}: expected "<id> <value>", got {line!r}'
                    )
                image_id, value = parts
                mapping[image_id] = value
        return mapping

    def load_images(self):
        """Collect the image paths and labels belonging to the selected split.

        Images are visited in the order they appear in ``images.txt``.

        Returns:
            tuple[list[str], list[str]]: relative image paths (under
            ``root/images``) and their class ids exactly as stored in
            ``image_class_labels.txt``. The ids are still strings;
            ``__getitem__`` converts them and subtracts one.

        Raises:
            KeyError: if an image id from ``images.txt`` is missing from the
                label or split file.
            ValueError: on a malformed index line or a non-integer split flag.
        """
        image_info = self._read_id_map(os.path.join(self.root, 'images.txt'))
        label_info = self._read_id_map(
            os.path.join(self.root, 'image_class_labels.txt'))
        split_info = self._read_id_map(
            os.path.join(self.root, 'train_test_split.txt'))

        # Split flag 1 selects the training subset, 0 the test subset.
        wanted_split = 1 if self.train else 0
        all_paths, all_labels = [], []
        for index, image_path in image_info.items():
            # Look the label up before checking the split (as before), so a
            # missing label raises KeyError regardless of which split the
            # image belongs to.
            label = label_info[index]
            if int(split_info[index]) == wanted_split:
                all_paths.append(image_path)
                all_labels.append(label)
        return all_paths, all_labels

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th sample of the split.

        The image is loaded as RGB and passed through ``image_transforms`` when
        one was given. The label is the stored class id minus one.
        """
        path = os.path.join(self.root, 'images', self.paths[index])
        label = int(self.labels[index]) - 1
        # Open the image as a context manager so the file handle is released
        # deterministically. convert() returns a new, fully loaded image that
        # does not depend on the open file.
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        if self.image_transforms:
            image = self.image_transforms(image)
        return image, label