|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
|
import os |
|
|
import time |
|
|
import torch.distributed as dist |
|
|
import torch.utils.data as data |
|
|
from PIL import Image |
|
|
|
|
|
from .zipreader import is_zip_path, ZipReader |
|
|
|
|
|
|
|
|
def has_file_allowed_extension(filename, extensions):
    """Report whether ``filename`` ends with one of the given extensions.

    Only ``filename`` is lower-cased before comparing; the entries of
    ``extensions`` are used exactly as given, so they should be lower-case.

    Args:
        filename (string): path to a file
        extensions (iterable[string]): allowed suffixes, e.g. ``['.jpg', '.png']``

    Returns:
        bool: True if the lower-cased filename ends with any entry of
        ``extensions``; False otherwise (including when ``extensions`` is empty).
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False
|
|
|
|
|
|
|
|
def find_classes(dir):
    """List the immediate subdirectories of ``dir`` as class names.

    Args:
        dir (string): directory whose child directories are the classes.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of subdirectory names and ``class_to_idx`` maps each name to its
        position in that sorted list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
|
|
|
|
|
|
|
|
def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(path, class_index)`` pairs from a class-per-subdirectory tree.

    Every immediate subdirectory of ``dir`` (visited in sorted order) is walked
    recursively; each file whose name passes ``has_file_allowed_extension`` is
    recorded with the index ``class_to_idx`` gives for that subdirectory.

    Args:
        dir (string): root directory (``~`` is expanded).
        class_to_idx (dict): maps subdirectory name to class index. The lookup
            only happens for a subdirectory that contains a matching file, so a
            missing key raises ``KeyError`` only in that case.
        extensions (list[string]): allowed file extensions.

    Returns:
        list: ``(path, class_index)`` tuples in sorted walk order.
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        # Sorting the walk tuples and the file names keeps the sample order
        # deterministic across filesystems.
        for walk_root, _subdirs, walk_files in sorted(os.walk(class_dir)):
            for fname in sorted(walk_files):
                if not has_file_allowed_extension(fname, extensions):
                    continue
                samples.append((os.path.join(walk_root, fname), class_to_idx[class_name]))
    return samples
|
|
|
|
|
|
|
|
def make_dataset_with_ann(ann_file, img_prefix, extensions):
    """Build ``(path, class_index)`` pairs from a tab-separated annotation file.

    Each non-blank line of ``ann_file`` holds an image path and an integer
    class index separated by a tab character. Blank (whitespace-only) lines,
    such as a trailing empty line at the end of the file, are skipped.

    Args:
        ann_file (string): path of the annotation file.
        img_prefix (string): directory the listed image paths are joined onto.
        extensions (list[string]): allowed lower-case file extensions; each
            listed path's extension is lower-cased before the membership test.

    Returns:
        list: ``(os.path.join(img_prefix, image_path), class_index)`` tuples in
        file order.

    Raises:
        ValueError: if a non-blank line has no tab separator, its class field is
            not an integer, or its file extension is not in ``extensions``.
    """
    images = []
    with open(ann_file, "r") as f:
        for line_no, line_str in enumerate(f, start=1):
            # Tolerate blank lines instead of crashing with IndexError.
            if not line_str.strip():
                continue
            path_contents = line_str.split('\t')
            if len(path_contents) < 2:
                raise ValueError(
                    f"{ann_file}:{line_no}: expected '<path>\\t<class index>', got {line_str!r}")
            im_file_name = path_contents[0]
            # int() ignores surrounding whitespace, including the trailing newline.
            class_index = int(path_contents[1])

            # Explicit check rather than `assert`, which is stripped under `python -O`.
            ext = os.path.splitext(im_file_name)[-1].lower()
            if ext not in extensions:
                raise ValueError(
                    f"{ann_file}:{line_no}: unsupported extension {ext!r} for {im_file_name!r}")

            images.append((os.path.join(img_prefix, im_file_name), class_index))
    return images
|
|
|
|
|
|
|
|
class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    When ``ann_file`` is non-empty, samples are instead read from that
    annotation file (see ``make_dataset_with_ann``).

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path (or,
            after caching, whatever ``ZipReader.read`` returned for it).
        extensions (list[string]): A list of allowed extensions.
        ann_file (string, optional): annotation file, joined onto ``root``.
            The empty string (default) means scan the folder layout instead.
        img_prefix (string, optional): directory, joined onto ``root``, that the
            paths listed in ``ann_file`` are joined onto.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        cache_mode (string, optional): ``"no"`` (default) disables caching.
            ``"full"`` pre-reads every sample with ``ZipReader.read``.
            ``"part"`` pre-reads only samples whose ``index % world_size == rank``.

    Attributes:
        samples (list): List of (sample path, class_index) tuples. Cached entries
            hold the result of ``ZipReader.read(path)`` in place of the path.
        labels (list): the class index of each sample, in sample order.
        classes (list): the sorted distinct class indices.
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        if ann_file == '':
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        else:
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)

        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]
        # sorted() gives a deterministic order; iteration order of a set is not guaranteed.
        self.classes = sorted(set(self.labels))

        self.transform = transform
        self.target_transform = target_transform

        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        """Pre-read samples into memory according to ``self.cache_mode``.

        Replaces ``self.samples`` with a list of the same length in which each
        selected entry's path is swapped for ``ZipReader.read(path)``. Entries
        not selected (``"part"`` mode, other ranks) keep their path. Progress is
        printed roughly every 10% of the samples.

        Raises:
            ValueError: if ``self.cache_mode`` is not ``"part"`` or ``"full"``.
        """
        if self.cache_mode not in ("part", "full"):
            raise ValueError(
                f"cache_mode must be 'part' or 'full' to build a cache, got {self.cache_mode!r}")
        n_sample = len(self.samples)
        if dist.is_available() and dist.is_initialized():
            global_rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            # Non-distributed run: behave as the only rank of a world of size 1.
            global_rank, world_size = 0, 1

        # Clamp to 1 so that datasets with fewer than 10 samples do not hit
        # a modulo-by-zero.
        log_every = max(n_sample // 10, 1)

        samples_bytes = []
        start_time = time.time()
        for index, (path, target) in enumerate(self.samples):
            if index % log_every == 0:
                t = time.time() - start_time
                print(f'global_rank {global_rank} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            if self.cache_mode == "full":
                samples_bytes.append((ZipReader.read(path), target))
            elif self.cache_mode == "part" and index % world_size == global_rank:
                samples_bytes.append((ZipReader.read(path), target))
            else:
                samples_bytes.append((path, target))
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
|
|
|
|
|
|
|
|
# Lower-case file suffixes (with leading dot) treated as images. They must be
# lower-case because only the file name is lower-cased when matching.
IMG_EXTENSIONS = [
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
]
|
|
|
|
|
|
|
|
def pil_loader(path):
    """Load an image with PIL and convert it to RGB.

    Args:
        path: one of
            * ``bytes`` - encoded image data held in memory (e.g. a cached sample),
            * a path for which ``is_zip_path`` is true - read via ``ZipReader.read``,
            * any other filesystem path - opened directly.

    Returns:
        PIL.Image.Image: the image converted to ``'RGB'`` mode.
    """
    if isinstance(path, bytes):
        img = Image.open(io.BytesIO(path))
    elif is_zip_path(path):
        raw = ZipReader.read(path)
        img = Image.open(io.BytesIO(raw))
    else:
        # Open the file ourselves so it is closed deterministically
        # (https://github.com/python-pillow/Pillow/issues/835).
        # Image.open is lazy: pixel data is only read on load()/convert(),
        # so the conversion must happen while the file is still open.
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')
    return img.convert('RGB')
|
|
|
|
|
|
|
|
def accimage_loader(path):
    """Load an image with accimage, falling back to ``pil_loader`` on I/O errors.

    Args:
        path: image location passed straight to ``accimage.Image``.

    Returns:
        An ``accimage.Image`` on success, otherwise whatever ``pil_loader``
        returns for the same ``path``.
    """
    import accimage
    try:
        img = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        # accimage could not handle it; retry with the PIL-based loader.
        return pil_loader(path)
    return img
|
|
|
|
|
|
|
|
def default_img_loader(path):
    """Load an image with the loader matching torchvision's image backend.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'``, and ``pil_loader`` otherwise.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)
|
|
|
|
|
|
|
|
class CachedImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Files are filtered by ``IMG_EXTENSIONS``. Indexing, inherited from
    ``DatasetFolder``, returns ``(image, target)``. The image is
    ``loader(path)``, passed through ``transform`` if one is given. The target
    is the class index, passed through ``target_transform`` if one is given.

    Args:
        root (string): Root directory path.
        ann_file (string, optional): annotation file joined onto ``root``. When
            empty (default), the images are discovered from the folder layout.
        img_prefix (string, optional): directory joined onto ``root`` that the
            paths listed in ``ann_file`` are joined onto.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            Defaults to ``default_img_loader``.
        cache_mode (string, optional): ``"no"``, ``"part"`` or ``"full"``;
            forwarded unchanged to ``DatasetFolder``.

    Attributes:
        imgs (list): List of (image path, class_index) tuples. This is the same
            list object as ``samples`` at the end of construction, i.e. after
            any caching has run.
    """

    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
                 loader=default_img_loader, cache_mode="no"):
        super().__init__(root, loader, IMG_EXTENSIONS,
                         ann_file=ann_file, img_prefix=img_prefix,
                         transform=transform, target_transform=target_transform,
                         cache_mode=cache_mode)
        self.imgs = self.samples
|
|
|