Image Classification
English
TTA
ReservoirTTA / robustbench /loaders.py
GuillaumeVray
Uploading files
02ba886
"""
This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py.
Adapted from: https://github.com/RobustBench/robustbench/blob/master/robustbench/loaders.py
"""
from torchvision.datasets.vision import VisionDataset
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
import os
import os.path
import sys
import json
import numpy as np
def make_custom_dataset(root, path_imgs, cls_dict):
with open(path_imgs, 'r') as f:
fnames = f.readlines()
with open(cls_dict, 'r') as f:
class_to_idx = json.load(f)
images = [(os.path.join(root, c.split('\n')[0]), class_to_idx[c.split(os.sep)[0]]) for c in fnames]
return images
class CustomDatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid_file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
super(CustomDatasetFolder, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
classes, class_to_idx = self._find_classes(self.root)
samples = make_custom_dataset(self.root, 'robustbench/data/imagenet_test_image_ids.txt',
'robustbench/data/imagenet_class_to_id_map.json')
if len(samples) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
else:
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
data = self.samples[index]
if len(data) == 3:
path, target, domain = data
else:
path, target = data
domain = path.split(os.sep)[-4]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, domain, path
def __len__(self):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class CustomImageFolder(CustomDatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid_file (used to check of corrupt files)
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None):
super(CustomImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
class CustomCifarDataset(data.Dataset):
def __init__(self, samples, transform=None):
super(CustomCifarDataset, self).__init__()
self.samples = samples
self.transform = transform
def __getitem__(self, index):
data = self.samples[index]
img, label, domain = self.samples[index]
if self.transform is not None:
img = Image.fromarray(np.uint8(img * 255.)).convert('RGB')
img = self.transform(img)
else:
img = torch.tensor(img.transpose((2, 0, 1)))
return img, torch.tensor(label), domain
def __len__(self):
return len(self.samples)
if __name__ == '__main__':
data_dir = '/home/scratch/datasets/imagenet/val'
imagenet = CustomImageFolder(data_dir, transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]))
torch.manual_seed(0)
test_loader = data.DataLoader(imagenet, batch_size=5000, shuffle=True, num_workers=30)
x, y, path = next(iter(test_loader))
with open('path_imgs_2.txt', 'w') as f:
f.write('\n'.join(path))
f.flush()