|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
|
import json |
|
|
import logging |
|
|
import math |
|
|
import os |
|
|
import os.path as osp |
|
|
import re |
|
|
import time |
|
|
from abc import abstractmethod |
|
|
|
|
|
import mmcv |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.utils.data as data |
|
|
from mmcv.fileio import FileClient |
|
|
from PIL import Image |
|
|
from tqdm import tqdm, trange |
|
|
|
|
|
from .zipreader import ZipReader, is_zip_path |
|
|
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
_ERROR_RETRY = 50 |
|
|
|
|
|
|
|
|
def has_file_allowed_extension(filename, extensions): |
|
|
"""Checks if a file is an allowed extension. |
|
|
|
|
|
Args: |
|
|
filename (string): path to a file |
|
|
Returns: |
|
|
bool: True if the filename ends with a known image extension |
|
|
""" |
|
|
filename_lower = filename.lower() |
|
|
return any(filename_lower.endswith(ext) for ext in extensions) |
|
|
|
|
|
|
|
|
def find_classes(dir): |
|
|
classes = [ |
|
|
d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) |
|
|
] |
|
|
classes.sort() |
|
|
class_to_idx = {classes[i]: i for i in range(len(classes))} |
|
|
return classes, class_to_idx |
|
|
|
|
|
|
|
|
def make_dataset(dir, class_to_idx, extensions): |
|
|
images = [] |
|
|
dir = os.path.expanduser(dir) |
|
|
for target in sorted(os.listdir(dir)): |
|
|
d = os.path.join(dir, target) |
|
|
if not os.path.isdir(d): |
|
|
continue |
|
|
for root, _, fnames in sorted(os.walk(d)): |
|
|
for fname in sorted(fnames): |
|
|
if has_file_allowed_extension(fname, extensions): |
|
|
path = os.path.join(root, fname) |
|
|
item = (path, class_to_idx[target]) |
|
|
images.append(item) |
|
|
|
|
|
return images |
|
|
|
|
|
|
|
|
def make_dataset_with_ann(ann_file, img_prefix, extensions): |
|
|
images = [] |
|
|
with open(ann_file, 'r') as f: |
|
|
contents = f.readlines() |
|
|
for line_str in contents: |
|
|
path_contents = [c for c in line_str.split('\t')] |
|
|
im_file_name = path_contents[0] |
|
|
class_index = int(path_contents[1]) |
|
|
assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions |
|
|
item = (os.path.join(img_prefix, im_file_name), class_index) |
|
|
images.append(item) |
|
|
|
|
|
return images |
|
|
|
|
|
|
|
|
class DatasetFolder(data.Dataset): |
|
|
"""A generic data loader where the samples are arranged in this way: :: |
|
|
|
|
|
root/class_x/xxx.ext |
|
|
root/class_x/xxy.ext |
|
|
root/class_x/xxz.ext |
|
|
root/class_y/123.ext |
|
|
root/class_y/nsdf3.ext |
|
|
root/class_y/asd932_.ext |
|
|
Args: |
|
|
root (string): Root directory path. |
|
|
loader (callable): A function to load a sample given its path. |
|
|
extensions (list[string]): A list of allowed extensions. |
|
|
transform (callable, optional): A function/transform that takes in |
|
|
a sample and returns a transformed version. |
|
|
E.g, ``transforms.RandomCrop`` for images. |
|
|
target_transform (callable, optional): A function/transform that takes |
|
|
in the target and transforms it. |
|
|
Attributes: |
|
|
samples (list): List of (sample path, class_index) tuples |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
root, |
|
|
loader, |
|
|
extensions, |
|
|
ann_file='', |
|
|
img_prefix='', |
|
|
transform=None, |
|
|
target_transform=None, |
|
|
cache_mode='no'): |
|
|
|
|
|
if ann_file == '': |
|
|
_, class_to_idx = find_classes(root) |
|
|
samples = make_dataset(root, class_to_idx, extensions) |
|
|
|
|
|
else: |
|
|
samples = make_dataset_with_ann(os.path.join(root, ann_file), |
|
|
os.path.join(root, img_prefix), |
|
|
extensions) |
|
|
|
|
|
if len(samples) == 0: |
|
|
raise (RuntimeError('Found 0 files in subfolders of: ' + root + |
|
|
'\n' + 'Supported extensions are: ' + |
|
|
','.join(extensions))) |
|
|
|
|
|
self.root = root |
|
|
self.loader = loader |
|
|
self.extensions = extensions |
|
|
|
|
|
self.samples = samples |
|
|
self.labels = [y_1k for _, y_1k in samples] |
|
|
self.classes = list(set(self.labels)) |
|
|
|
|
|
self.transform = transform |
|
|
self.target_transform = target_transform |
|
|
|
|
|
self.cache_mode = cache_mode |
|
|
if self.cache_mode != 'no': |
|
|
self.init_cache() |
|
|
|
|
|
def init_cache(self): |
|
|
assert self.cache_mode in ['part', 'full'] |
|
|
n_sample = len(self.samples) |
|
|
global_rank = dist.get_rank() |
|
|
world_size = dist.get_world_size() |
|
|
|
|
|
samples_bytes = [None for _ in range(n_sample)] |
|
|
start_time = time.time() |
|
|
for index in range(n_sample): |
|
|
if index % (n_sample // 10) == 0: |
|
|
t = time.time() - start_time |
|
|
print( |
|
|
f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block' |
|
|
) |
|
|
start_time = time.time() |
|
|
path, target = self.samples[index] |
|
|
if self.cache_mode == 'full': |
|
|
samples_bytes[index] = (ZipReader.read(path), target) |
|
|
elif self.cache_mode == 'part' and index % world_size == global_rank: |
|
|
samples_bytes[index] = (ZipReader.read(path), target) |
|
|
else: |
|
|
samples_bytes[index] = (path, target) |
|
|
self.samples = samples_bytes |
|
|
|
|
|
def __getitem__(self, index): |
|
|
""" |
|
|
Args: |
|
|
index (int): Index |
|
|
Returns: |
|
|
tuple: (sample, target) where target is class_index of the target class. |
|
|
""" |
|
|
path, target = self.samples[index] |
|
|
sample = self.loader(path) |
|
|
if self.transform is not None: |
|
|
sample = self.transform(sample) |
|
|
if self.target_transform is not None: |
|
|
target = self.target_transform(target) |
|
|
|
|
|
return sample, target |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.samples) |
|
|
|
|
|
def __repr__(self): |
|
|
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' |
|
|
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) |
|
|
fmt_str += ' Root Location: {}\n'.format(self.root) |
|
|
tmp = ' Transforms (if any): ' |
|
|
fmt_str += '{0}{1}\n'.format( |
|
|
tmp, |
|
|
self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) |
|
|
tmp = ' Target Transforms (if any): ' |
|
|
fmt_str += '{0}{1}'.format( |
|
|
tmp, |
|
|
self.target_transform.__repr__().replace('\n', |
|
|
'\n' + ' ' * len(tmp))) |
|
|
|
|
|
return fmt_str |
|
|
|
|
|
|
|
|
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] |
|
|
|
|
|
|
|
|
def pil_loader(path): |
|
|
|
|
|
if isinstance(path, bytes): |
|
|
img = Image.open(io.BytesIO(path)) |
|
|
elif is_zip_path(path): |
|
|
data = ZipReader.read(path) |
|
|
img = Image.open(io.BytesIO(data)) |
|
|
else: |
|
|
with open(path, 'rb') as f: |
|
|
img = Image.open(f) |
|
|
return img.convert('RGB') |
|
|
|
|
|
return img.convert('RGB') |
|
|
|
|
|
|
|
|
def accimage_loader(path): |
|
|
import accimage |
|
|
try: |
|
|
return accimage.Image(path) |
|
|
except IOError: |
|
|
|
|
|
return pil_loader(path) |
|
|
|
|
|
|
|
|
def default_img_loader(path): |
|
|
from torchvision import get_image_backend |
|
|
if get_image_backend() == 'accimage': |
|
|
return accimage_loader(path) |
|
|
else: |
|
|
return pil_loader(path) |
|
|
|
|
|
|
|
|
class CachedImageFolder(DatasetFolder): |
|
|
"""A generic data loader where the images are arranged in this way: :: |
|
|
|
|
|
root/dog/xxx.png |
|
|
root/dog/xxy.png |
|
|
root/dog/xxz.png |
|
|
root/cat/123.png |
|
|
root/cat/nsdf3.png |
|
|
root/cat/asd932_.png |
|
|
Args: |
|
|
root (string): Root directory path. |
|
|
transform (callable, optional): A function/transform that takes in an PIL image |
|
|
and returns a transformed version. E.g, ``transforms.RandomCrop`` |
|
|
target_transform (callable, optional): A function/transform that takes in the |
|
|
target and transforms it. |
|
|
loader (callable, optional): A function to load an image given its path. |
|
|
Attributes: |
|
|
imgs (list): List of (image path, class_index) tuples |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
root, |
|
|
ann_file='', |
|
|
img_prefix='', |
|
|
transform=None, |
|
|
target_transform=None, |
|
|
loader=default_img_loader, |
|
|
cache_mode='no'): |
|
|
super(CachedImageFolder, |
|
|
self).__init__(root, |
|
|
loader, |
|
|
IMG_EXTENSIONS, |
|
|
ann_file=ann_file, |
|
|
img_prefix=img_prefix, |
|
|
transform=transform, |
|
|
target_transform=target_transform, |
|
|
cache_mode=cache_mode) |
|
|
self.imgs = self.samples |
|
|
|
|
|
def __getitem__(self, index): |
|
|
""" |
|
|
Args: |
|
|
index (int): Index |
|
|
Returns: |
|
|
tuple: (image, target) where target is class_index of the target class. |
|
|
""" |
|
|
path, target = self.samples[index] |
|
|
image = self.loader(path) |
|
|
if self.transform is not None: |
|
|
img = self.transform(image) |
|
|
else: |
|
|
img = image |
|
|
if self.target_transform is not None: |
|
|
target = self.target_transform(target) |
|
|
|
|
|
return img, target |
|
|
|
|
|
|
|
|
class ImageCephDataset(data.Dataset): |
|
|
|
|
|
def __init__(self, |
|
|
root, |
|
|
split, |
|
|
parser=None, |
|
|
transform=None, |
|
|
target_transform=None, |
|
|
on_memory=False): |
|
|
if '22k' in root: |
|
|
|
|
|
annotation_root = 'meta_data/' |
|
|
else: |
|
|
|
|
|
annotation_root = 'meta_data/' |
|
|
if parser is None or isinstance(parser, str): |
|
|
parser = ParserCephImage(root=root, |
|
|
split=split, |
|
|
annotation_root=annotation_root, |
|
|
on_memory=on_memory) |
|
|
self.parser = parser |
|
|
self.transform = transform |
|
|
self.target_transform = target_transform |
|
|
self._consecutive_errors = 0 |
|
|
|
|
|
def __getitem__(self, index): |
|
|
img, target = self.parser[index] |
|
|
self._consecutive_errors = 0 |
|
|
if self.transform is not None: |
|
|
img = self.transform(img) |
|
|
if target is None: |
|
|
target = -1 |
|
|
elif self.target_transform is not None: |
|
|
target = self.target_transform(target) |
|
|
return img, target |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.parser) |
|
|
|
|
|
def filename(self, index, basename=False, absolute=False): |
|
|
return self.parser.filename(index, basename, absolute) |
|
|
|
|
|
def filenames(self, basename=False, absolute=False): |
|
|
return self.parser.filenames(basename, absolute) |
|
|
|
|
|
|
|
|
class Parser: |
|
|
|
|
|
def __init__(self): |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def _filename(self, index, basename=False, absolute=False): |
|
|
pass |
|
|
|
|
|
def filename(self, index, basename=False, absolute=False): |
|
|
return self._filename(index, basename=basename, absolute=absolute) |
|
|
|
|
|
def filenames(self, basename=False, absolute=False): |
|
|
return [ |
|
|
self._filename(index, basename=basename, absolute=absolute) |
|
|
for index in range(len(self)) |
|
|
] |
|
|
|
|
|
|
|
|
class ParserCephImage(Parser): |
|
|
|
|
|
def __init__(self, |
|
|
root, |
|
|
split, |
|
|
annotation_root, |
|
|
on_memory=False, |
|
|
**kwargs): |
|
|
super().__init__() |
|
|
|
|
|
self.file_client = None |
|
|
self.kwargs = kwargs |
|
|
|
|
|
self.root = root |
|
|
if '22k' in root: |
|
|
self.io_backend = 'petrel' |
|
|
with open(osp.join(annotation_root, '22k_class_to_idx.json'), |
|
|
'r') as f: |
|
|
self.class_to_idx = json.loads(f.read()) |
|
|
with open(osp.join(annotation_root, '22k_label.txt'), 'r') as f: |
|
|
self.samples = f.read().splitlines() |
|
|
else: |
|
|
self.io_backend = 'disk' |
|
|
self.class_to_idx = None |
|
|
with open(osp.join(annotation_root, f'{split}.txt'), 'r') as f: |
|
|
self.samples = f.read().splitlines() |
|
|
local_rank = None |
|
|
local_size = None |
|
|
self._consecutive_errors = 0 |
|
|
self.on_memory = on_memory |
|
|
if on_memory: |
|
|
self.holder = {} |
|
|
if local_rank is None: |
|
|
local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
|
|
if local_size is None: |
|
|
local_size = int(os.environ.get('LOCAL_SIZE', 1)) |
|
|
self.local_rank = local_rank |
|
|
self.local_size = local_size |
|
|
self.rank = int(os.environ['RANK']) |
|
|
self.world_size = int(os.environ['WORLD_SIZE']) |
|
|
self.num_replicas = int(os.environ['WORLD_SIZE']) |
|
|
self.num_parts = local_size |
|
|
self.num_samples = int( |
|
|
math.ceil(len(self.samples) * 1.0 / self.num_replicas)) |
|
|
self.total_size = self.num_samples * self.num_replicas |
|
|
self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts |
|
|
self.load_onto_memory_v2() |
|
|
|
|
|
def load_onto_memory(self): |
|
|
print('Loading images onto memory...', self.local_rank, |
|
|
self.local_size) |
|
|
if self.file_client is None: |
|
|
self.file_client = FileClient(self.io_backend, **self.kwargs) |
|
|
for index in trange(len(self.samples)): |
|
|
if index % self.local_size != self.local_rank: |
|
|
continue |
|
|
path, _ = self.samples[index].split(' ') |
|
|
path = osp.join(self.root, path) |
|
|
img_bytes = self.file_client.get(path) |
|
|
self.holder[path] = img_bytes |
|
|
|
|
|
print('Loading complete!') |
|
|
|
|
|
def load_onto_memory_v2(self): |
|
|
|
|
|
t = torch.Generator() |
|
|
t.manual_seed(0) |
|
|
indices = torch.randperm(len(self.samples), generator=t).tolist() |
|
|
|
|
|
indices = [i for i in indices if i % self.num_parts == self.local_rank] |
|
|
|
|
|
indices += indices[:(self.total_size_parts - len(indices))] |
|
|
assert len(indices) == self.total_size_parts |
|
|
|
|
|
|
|
|
indices = indices[self.rank // self.num_parts:self. |
|
|
total_size_parts:self.num_replicas // self.num_parts] |
|
|
assert len(indices) == self.num_samples |
|
|
|
|
|
if self.file_client is None: |
|
|
self.file_client = FileClient(self.io_backend, **self.kwargs) |
|
|
for index in tqdm(indices): |
|
|
if index % self.local_size != self.local_rank: |
|
|
continue |
|
|
path, _ = self.samples[index].split(' ') |
|
|
path = osp.join(self.root, path) |
|
|
img_bytes = self.file_client.get(path) |
|
|
|
|
|
self.holder[path] = img_bytes |
|
|
|
|
|
print('Loading complete!') |
|
|
|
|
|
def __getitem__(self, index): |
|
|
if self.file_client is None: |
|
|
self.file_client = FileClient(self.io_backend, **self.kwargs) |
|
|
|
|
|
filepath, target = self.samples[index].split(' ') |
|
|
filepath = osp.join(self.root, filepath) |
|
|
|
|
|
try: |
|
|
if self.on_memory: |
|
|
img_bytes = self.holder[filepath] |
|
|
else: |
|
|
|
|
|
img_bytes = self.file_client.get(filepath) |
|
|
img = mmcv.imfrombytes(img_bytes)[:, :, ::-1] |
|
|
except Exception as e: |
|
|
_logger.warning( |
|
|
f'Skipped sample (index {index}, file {filepath}). {str(e)}') |
|
|
self._consecutive_errors += 1 |
|
|
if self._consecutive_errors < _ERROR_RETRY: |
|
|
return self.__getitem__((index + 1) % len(self)) |
|
|
else: |
|
|
raise e |
|
|
self._consecutive_errors = 0 |
|
|
|
|
|
img = Image.fromarray(img) |
|
|
try: |
|
|
if self.class_to_idx is not None: |
|
|
target = self.class_to_idx[target] |
|
|
else: |
|
|
target = int(target) |
|
|
except: |
|
|
print(filepath, target) |
|
|
exit() |
|
|
|
|
|
return img, target |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.samples) |
|
|
|
|
|
def _filename(self, index, basename=False, absolute=False): |
|
|
filename, _ = self.samples[index].split(' ') |
|
|
filename = osp.join(self.root, filename) |
|
|
|
|
|
return filename |
|
|
|
|
|
|
|
|
def get_temporal_info(date, miss_hour=False): |
|
|
try: |
|
|
if date: |
|
|
if miss_hour: |
|
|
pattern = re.compile(r'(\d*)-(\d*)-(\d*)', re.I) |
|
|
else: |
|
|
pattern = re.compile(r'(\d*)-(\d*)-(\d*) (\d*):(\d*):(\d*)', |
|
|
re.I) |
|
|
m = pattern.match(date.strip()) |
|
|
|
|
|
if m: |
|
|
year = int(m.group(1)) |
|
|
month = int(m.group(2)) |
|
|
day = int(m.group(3)) |
|
|
x_month = math.sin(2 * math.pi * month / 12) |
|
|
y_month = math.cos(2 * math.pi * month / 12) |
|
|
if miss_hour: |
|
|
x_hour = 0 |
|
|
y_hour = 0 |
|
|
else: |
|
|
hour = int(m.group(4)) |
|
|
x_hour = math.sin(2 * math.pi * hour / 24) |
|
|
y_hour = math.cos(2 * math.pi * hour / 24) |
|
|
return [x_month, y_month, x_hour, y_hour] |
|
|
else: |
|
|
return [0, 0, 0, 0] |
|
|
else: |
|
|
return [0, 0, 0, 0] |
|
|
except: |
|
|
return [0, 0, 0, 0] |
|
|
|
|
|
|
|
|
def get_spatial_info(latitude, longitude): |
|
|
if latitude and longitude: |
|
|
latitude = math.radians(latitude) |
|
|
longitude = math.radians(longitude) |
|
|
x = math.cos(latitude) * math.cos(longitude) |
|
|
y = math.cos(latitude) * math.sin(longitude) |
|
|
z = math.sin(latitude) |
|
|
return [x, y, z] |
|
|
else: |
|
|
return [0, 0, 0] |
|
|
|