xmutly's picture
Upload 294 files
e1aaaac verified
# Code from https://github.com/SsnL/dataset-distillation/blob/master/datasets/pascal_voc.py , thanks to the authors
"""Dataset setting and data loader for PASCAL VOC 2007 as a classification task.
Modified from
https://github.com/Cadene/pretrained-models.pytorch/blob/56aa8c921819d14fb36d7248ab71e191b37cb146/pretrainedmodels/datasets/voc.py
"""
import os
import os.path
import tarfile
import xml.etree.ElementTree as ET
import torch.utils.data as data
import torchvision
from PIL import Image
from urllib.parse import urlparse
import torch
# The 20 PASCAL VOC object classes; a class's integer label is its position here.
object_categories = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
# Reverse lookup used when parsing annotations: class name -> integer label.
category_to_idx = dict(zip(object_categories, range(len(object_categories))))
# Download locations of the VOC 2007 archives, keyed by the names that
# download_voc2007 looks them up with.
urls = dict(
    devkit='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar',
    trainval_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
    test_images_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
    test_anno_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar',
)
def download_url(url, path):
    """Fetch ``url`` and store it at the file path ``path``.

    Thin adapter over ``torchvision.datasets.utils.download_url``, which takes
    a directory plus file name instead of a full path. No checksum is verified
    (``md5=None``).
    """
    target_dir = os.path.dirname(path)
    target_name = os.path.basename(path)
    torchvision.datasets.utils.download_url(
        url,
        root=target_dir,
        filename=target_name,
        md5=None,
    )
def _fetch_and_extract(url, root, tmpdir):
    """Download ``url`` into ``tmpdir`` (unless already cached there) and extract it into ``root``."""
    os.makedirs(tmpdir, exist_ok=True)
    cached_file = os.path.join(tmpdir, os.path.basename(urlparse(url).path))
    if not os.path.exists(cached_file):
        download_url(url, cached_file)
    print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
    # Extract relative to ``root`` rather than os.chdir()-ing into it: chdir changes
    # process-wide state and was not restored if extraction raised.
    with tarfile.open(cached_file, "r") as tar:
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ (and security backports): refuse members that would land
            # outside ``root`` -- the archives are fetched over plain http.
            tar.extractall(path=root, filter='data')
        else:
            tar.extractall(path=root)
    print('[dataset] Done!')


def download_voc2007(root):
    """Make sure the PASCAL VOC 2007 files exist under ``root``.

    Each archive listed in ``urls`` is downloaded (cached in ``root/tmp``) and
    extracted into ``root`` only when the marker path it is expected to provide
    is missing, so calling this on a complete dataset does nothing:

    * ``VOCdevkit``                                        <- ``devkit``
    * ``VOCdevkit/VOC2007/JPEGImages``                     <- ``trainval_2007``
    * ``VOCdevkit/VOC2007/JPEGImages/000001.jpg``          <- ``test_images_2007``
    * ``VOCdevkit/VOC2007/ImageSets/Main/aeroplane_test.txt`` <- ``test_anno_2007``

    Args:
        root: directory that will contain ``VOCdevkit`` and the ``tmp`` download cache.
    """
    path_devkit = os.path.join(root, 'VOCdevkit')
    path_voc = os.path.join(path_devkit, 'VOC2007')
    tmpdir = os.path.join(root, 'tmp')
    os.makedirs(root, exist_ok=True)

    # (marker path, key into ``urls``). The full test archive comes before the
    # annotations-only one: it is expected to carry the test annotations as well,
    # in which case the last step is skipped. Previously these two archives were
    # swapped, so a missing 000001.jpg fetched the archive *without* images.
    steps = (
        (path_devkit, 'devkit'),
        (os.path.join(path_voc, 'JPEGImages'), 'trainval_2007'),
        (os.path.join(path_voc, 'JPEGImages', '000001.jpg'), 'test_images_2007'),
        (os.path.join(path_voc, 'ImageSets', 'Main', 'aeroplane_test.txt'), 'test_anno_2007'),
    )
    for marker, key in steps:
        if not os.path.exists(marker):
            _fetch_and_extract(urls[key], root, tmpdir)
def read_split(root, dataset, split):
    """Return the image ids belonging to ``split`` as a tuple.

    The ids are taken from
    ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<first category>_<split>.txt``.
    Blank lines are skipped. Every other line must hold exactly two
    whitespace-separated fields (an ``AssertionError`` is raised otherwise),
    and only the first field, the image id, is kept.
    """
    list_file = os.path.join(
        root, 'VOCdevkit', dataset, 'ImageSets', 'Main',
        ''.join((object_categories[0], '_', split, '.txt')),
    )
    ids = []
    with open(list_file, 'r') as handle:
        for raw in handle:
            fields = raw.split()
            if not fields:
                continue
            assert len(fields) == 2
            ids.append(fields[0])
    return tuple(ids)
def read_bndbox(root, dataset, paths, class_to_idx=None):
    """Collect every annotated object for the given image ids.

    Parses ``<root>/VOCdevkit/<dataset>/Annotations/<id>.xml`` for each id in
    ``paths`` and emits one tuple per ``<object>`` element, in file order.

    Args:
        root: dataset root containing ``VOCdevkit``.
        dataset: sub-directory under ``VOCdevkit``, e.g. ``'VOC2007'``.
        paths: iterable of image ids (file stems without extension).
        class_to_idx: mapping from category name to integer label. Defaults to
            the module-level ``category_to_idx``.

    Returns:
        list of ``(id, (xmin, ymin, xmax, ymax), class_index)`` tuples. The
        coordinates are the integers stored in the XML, unconverted.

    Raises:
        ValueError: an ``<object>`` has no ``<name>`` or no ``<bndbox>`` child.
        KeyError: an object's name is not a key of ``class_to_idx``.
    """
    if class_to_idx is None:
        class_to_idx = category_to_idx
    xml_base = os.path.join(root, 'VOCdevkit', dataset, 'Annotations')
    instances = []
    for path in paths:
        tree = ET.parse(os.path.join(xml_base, path + '.xml'))
        for obj in tree.findall('object'):
            name = obj.find('name')
            bndbox = obj.find('bndbox')
            # Explicit checks instead of ``assert`` so validation survives ``python -O``.
            if name is None or bndbox is None:
                raise ValueError(
                    'object in {}.xml lacks a <name> or <bndbox> element'.format(path))
            # Look children up by tag rather than by position so the parse does not
            # depend on element order inside <object> / <bndbox>.
            box = tuple(int(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
            instances.append((path, box, class_to_idx[name.text.strip()]))
    return instances
class PASCALVoc2007(data.Dataset):
    """Multi-label classification view of PASCAL VOC 2007.

    Each item is ``(image, target)``. ``target`` is a float tensor of shape
    ``(C,)``, with ``C = len(object_categories)``. It holds 1 for every class
    that has at least one annotated object in the image and 0 otherwise.

    Args:
        root: directory containing (or that will receive) ``VOCdevkit``.
        set: image-set name selecting ``aeroplane_<set>.txt`` (e.g. ``'trainval'``, ``'test'``).
        transform: optional callable applied to the RGB ``PIL.Image``.
        download: if true, call ``download_voc2007(root)`` before reading.
        target_transform: optional callable applied to the label tensor.
    """

    def __init__(self, root, set, transform=None, download=False, target_transform=None):
        self.root = root
        self.path_devkit = os.path.join(root, 'VOCdevkit')
        self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
        self.transform = transform
        self.target_transform = target_transform
        if download:
            download_voc2007(self.root)
        paths = read_split(self.root, 'VOC2007', set)
        # Row of ``labels`` for each image id.
        row_of = {p: i for i, p in enumerate(paths)}
        labels = torch.zeros(len(paths), len(object_categories))
        for path, _bbox, c in read_bndbox(self.root, 'VOC2007', paths):
            labels[row_of[path], c] = 1
        self.labels = labels              # (num_images, num_classes) presence matrix
        self.classes = object_categories  # class names, index == label
        self.paths = paths                # tuple of image ids, index == dataset index

    def __getitem__(self, index):
        path = self.paths[index]
        # The context manager closes the file even if decoding fails; convert()
        # returns an independent, fully loaded copy that outlives the ``with``.
        with Image.open(os.path.join(self.path_images, path + '.jpg')) as raw:
            img = raw.convert('RGB')
        target = self.labels[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.paths)
class PASCALVoc2007Cropped(data.Dataset):
    """Single-label classification view of PASCAL VOC 2007 built from object crops.

    VOC 2007 is natively a multi-label detection dataset. Here every annotated
    bounding box becomes its own sample: the image is cropped to that box and
    labelled with the box's integer class index. One image can therefore yield
    several samples.

    Args:
        root: directory containing (or that will receive) ``VOCdevkit``.
        set: image-set name selecting ``aeroplane_<set>.txt`` (e.g. ``'trainval'``, ``'test'``).
        transform: optional callable applied to the cropped RGB ``PIL.Image``.
        download: if true, call ``download_voc2007(root)`` before reading.
        target_transform: optional callable applied to the integer label.
    """

    def __init__(self, root, set, transform=None, download=False, target_transform=None):
        self.root = root
        self.path_devkit = os.path.join(root, 'VOCdevkit')
        self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
        self.transform = transform
        self.target_transform = target_transform
        if download:
            download_voc2007(self.root)
        paths = read_split(self.root, 'VOC2007', set)
        # One (image id, (xmin, ymin, xmax, ymax), class index) entry per sample.
        self.bndboxes = read_bndbox(self.root, 'VOC2007', paths)
        self.classes = object_categories
        print('[dataset] VOC 2007 classification set=%s number of classes=%d number of bndboxes=%d' % (
            set, len(self.classes), len(self.bndboxes)))

    def __getitem__(self, index):
        path, (xmin, ymin, xmax, ymax), target = self.bndboxes[index]
        # The context manager closes the file even if decoding fails; convert()
        # returns an independent, fully loaded copy that outlives the ``with``.
        with Image.open(os.path.join(self.path_images, path + '.jpg')) as raw:
            img = raw.convert('RGB')
        # VOC annotations use 1-based pixel indices inclusive on both ends, while
        # PIL's crop box is 0-based with exclusive right/lower edges. Shifting only
        # the left/top edge keeps the whole annotated extent; the previous
        # ``crop((xmin, ymin, xmax, ymax))`` dropped the first column and row.
        img = img.crop((xmin - 1, ymin - 1, xmax, ymax))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.bndboxes)