| |
| """Dataset setting and data loader for PASCAL VOC 2007 as a classification task. |
| |
| Modified from |
| https://github.com/Cadene/pretrained-models.pytorch/blob/56aa8c921819d14fb36d7248ab71e191b37cb146/pretrainedmodels/datasets/voc.py |
| """ |
|
|
| import os |
| import os.path |
| import tarfile |
| import xml.etree.ElementTree as ET |
|
|
| import torch.utils.data as data |
| import torchvision |
| from PIL import Image |
| from urllib.parse import urlparse |
| import torch |
|
|
# The 20 PASCAL VOC object classes; a class's position in this list is its
# integer label.
object_categories = [
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor',
]


# Reverse lookup: class name -> integer label.
category_to_idx = dict(zip(object_categories, range(len(object_categories))))


# Download locations of the VOC 2007 archives, keyed by a short identifier.
urls = dict(
    devkit='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar',
    trainval_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
    test_images_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
    test_anno_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar',
)
|
|
|
|
def download_url(url, path):
    """Download ``url`` and store it at the file location ``path``.

    ``path`` is split into its directory and file name, which are handed to
    torchvision's ``download_url`` helper; no checksum is requested
    (``md5=None``).
    """
    target_dir = os.path.dirname(path)
    target_name = os.path.basename(path)
    torchvision.datasets.utils.download_url(
        url, root=target_dir, filename=target_name, md5=None)
|
|
|
|
def _extract_tar(archive, dest):
    """Unpack the tar file ``archive`` into the directory ``dest``."""
    print('[dataset] Extracting tar file {file} to {path}'.format(file=archive, path=dest))
    # Extract straight into ``dest`` instead of chdir-ing there first: the
    # working directory is process-wide state and would stay changed if the
    # extraction raised. The context manager also closes the archive on error.
    # NOTE(review): members are extracted without filtering; consider tarfile
    # extraction filters on Python versions that provide them.
    with tarfile.open(archive, "r") as tar:
        tar.extractall(path=dest)
    print('[dataset] Done!')


def _fetch_and_extract(key, root, tmpdir):
    """Download the archive ``urls[key]`` into ``tmpdir`` (if not cached) and unpack it into ``root``."""
    url = urls[key]
    cached_file = os.path.join(tmpdir, os.path.basename(urlparse(url).path))

    if not os.path.exists(cached_file):
        # Make sure the cache directory exists no matter which archive is the
        # first one that has to be downloaded.
        os.makedirs(tmpdir, exist_ok=True)
        download_url(url, cached_file)

    _extract_tar(cached_file, root)


def download_voc2007(root):
    """Make sure the PASCAL VOC 2007 files are present under ``root``.

    Four archives (see ``urls``) are handled one after the other. Each one is
    fetched and unpacked into ``root`` only when a marker path that it is
    expected to provide is missing:

    * ``devkit``           -- ``<root>/VOCdevkit``
    * ``trainval_2007``    -- ``<root>/VOCdevkit/VOC2007/JPEGImages``
    * ``test_images_2007`` -- ``.../VOC2007/ImageSets/Main/aeroplane_test.txt``
    * ``test_anno_2007``   -- ``.../VOC2007/JPEGImages/000001.jpg``

    Downloaded tar files are kept in ``<root>/tmp`` and reused on later calls.
    Nothing is downloaded or created when all marker paths already exist.

    Args:
        root: directory that holds (or will hold) ``VOCdevkit``; it is created
            if it does not exist.
    """
    path_devkit = os.path.join(root, 'VOCdevkit')
    path_images = os.path.join(path_devkit, 'VOC2007', 'JPEGImages')
    tmpdir = os.path.join(root, 'tmp')

    os.makedirs(root, exist_ok=True)

    # Development kit.
    if not os.path.exists(path_devkit):
        _fetch_and_extract('devkit', root, tmpdir)

    # Train/val images and annotations.
    if not os.path.exists(path_images):
        _fetch_and_extract('trainval_2007', root, tmpdir)

    # NOTE(review): the 'test_images_2007' archive is triggered by a missing
    # test split file and 'test_anno_2007' by a missing test image. This
    # pairing is kept exactly as it was -- confirm that it is intended.
    test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt')
    if not os.path.exists(test_anno):
        _fetch_and_extract('test_images_2007', root, tmpdir)

    test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg')
    if not os.path.exists(test_image):
        _fetch_and_extract('test_anno_2007', root, tmpdir)
|
|
|
|
def read_split(root, dataset, split):
    """Return the image identifiers that belong to one split of ``dataset``.

    The identifiers are read from the ``ImageSets/Main`` list file of the
    first object category (``<category>_<split>.txt``). Blank lines are
    skipped; every other line must consist of exactly two
    whitespace-separated fields, of which only the first (the image id) is
    kept.

    Args:
        root: directory containing ``VOCdevkit``.
        dataset: dataset folder name inside ``VOCdevkit`` (e.g. ``'VOC2007'``).
        split: split name used in the list file name.

    Returns:
        Tuple of image ids, in file order.
    """
    list_name = ''.join((object_categories[0], '_', split, '.txt'))
    list_file = os.path.join(
        root, 'VOCdevkit', dataset, 'ImageSets', 'Main', list_name)

    image_ids = []
    with open(list_file, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if not fields:
                continue
            assert len(fields) == 2
            image_ids.append(fields[0])

    return tuple(image_ids)
|
|
|
|
def read_bndbox(root, dataset, paths):
    """Collect every annotated object of the given images.

    For each image id in ``paths`` the file ``Annotations/<id>.xml`` is
    parsed. The first child of each ``object`` element must be its ``name``
    tag, and the first four children of ``bndbox`` are read, in order, as
    xmin, ymin, xmax, ymax.

    Args:
        root: directory containing ``VOCdevkit``.
        dataset: dataset folder name inside ``VOCdevkit``.
        paths: iterable of image ids (file names without extension).

    Returns:
        List of ``(image_id, (xmin, ymin, xmax, ymax), class_index)`` tuples,
        one per annotated object.
    """
    annotation_dir = os.path.join(root, 'VOCdevkit', dataset, 'Annotations')
    records = []
    for image_id in paths:
        tree = ET.parse(os.path.join(annotation_dir, image_id + '.xml'))
        for node in tree.findall('object'):
            name_node = node[0]
            assert name_node.tag == 'name', name_node.tag
            class_idx = category_to_idx[name_node.text]
            box_node = node.find('bndbox')
            # Coordinates are taken by position: xmin, ymin, xmax, ymax.
            box = tuple(int(box_node[k].text) for k in range(4))
            records.append((image_id, box, class_idx))
    return records
|
|
|
|
class PASCALVoc2007(data.Dataset):
    """
    PASCAL VOC 2007 as a multi-label image classification dataset.

    Each sample is an ``(image, target)`` pair. ``target`` is a multi-hot
    vector of shape (C,), C being the number of object categories: entry
    ``c`` is 1 when at least one annotated object of class ``c`` is present
    in the image and 0 otherwise.
    """

    def __init__(self, root, set, transform=None, download=False, target_transform=None):
        """Read the split ``set`` of VOC 2007 located under ``root``.

        Args:
            root: directory containing (or receiving) ``VOCdevkit``.
            set: name of the split to read (passed on to ``read_split``).
            transform: optional callable applied to the RGB PIL image.
            download: when true, ``download_voc2007(root)`` is called first.
            target_transform: optional callable applied to the label vector.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.path_devkit = os.path.join(root, 'VOCdevkit')
        self.path_images = os.path.join(self.path_devkit, 'VOC2007', 'JPEGImages')

        if download:
            download_voc2007(self.root)

        image_ids = read_split(self.root, 'VOC2007', set)
        annotations = read_bndbox(self.root, 'VOC2007', image_ids)

        # Row of the label matrix that belongs to each image id.
        row_of = {image_id: row for row, image_id in enumerate(image_ids)}

        # One row per image, one column per class; mark every class that has
        # at least one bounding box in the image.
        multi_hot = torch.zeros(len(image_ids), len(object_categories))
        for image_id, _box, class_idx in annotations:
            multi_hot[row_of[image_id], class_idx] = 1

        self.labels = multi_hot
        self.classes = object_categories
        self.paths = image_ids

    def __getitem__(self, index):
        """Return the (optionally transformed) image and label vector at ``index``."""
        image_file = os.path.join(self.path_images, self.paths[index] + '.jpg')
        img = Image.open(image_file).convert('RGB')
        target = self.labels[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of images in the split."""
        return len(self.paths)
|
|
class PASCALVoc2007Cropped(data.Dataset):
    """
    PASCAL VOC 2007 turned into a single-label classification dataset.

    VOC 2007 is an object-detection (hence multi-label) dataset. Here every
    annotated bounding box becomes its own sample: the image is cropped to
    the box and labelled with the class index of the object inside it.
    """

    def __init__(self, root, set, transform=None, download=False, target_transform=None):
        """Read all bounding boxes of the split ``set`` located under ``root``.

        Args:
            root: directory containing (or receiving) ``VOCdevkit``.
            set: name of the split to read (passed on to ``read_split``).
            transform: optional callable applied to the cropped RGB image.
            download: when true, ``download_voc2007(root)`` is called first.
            target_transform: optional callable applied to the class index.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.path_devkit = os.path.join(root, 'VOCdevkit')
        self.path_images = os.path.join(self.path_devkit, 'VOC2007', 'JPEGImages')

        if download:
            download_voc2007(self.root)

        image_ids = read_split(self.root, 'VOC2007', set)
        self.bndboxes = read_bndbox(self.root, 'VOC2007', image_ids)
        self.classes = object_categories

        summary = (set, len(self.classes), len(self.bndboxes))
        print('[dataset] VOC 2007 classification set=%s number of classes=%d number of bndboxes=%d' % summary)

    def __getitem__(self, index):
        """Return the (optionally transformed) cropped object and its class index."""
        image_id, box, target = self.bndboxes[index]
        image_file = os.path.join(self.path_images, image_id + '.jpg')
        # ``box`` is (xmin, ymin, xmax, ymax) as stored by ``read_bndbox``.
        img = Image.open(image_file).convert('RGB').crop(box)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of bounding boxes (i.e. samples) in the split."""
        return len(self.bndboxes)
|
|