Spaces:
Running
Running
import math
import os
import re
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
class ToNumpy:
    """Callable transform that turns an image-like input into a uint8 array
    laid out channels-first."""

    def __call__(self, pil_img):
        """Return ``pil_img`` as a ``np.uint8`` array with axis 2 moved to the front.

        Inputs with fewer than three dimensions get a trailing axis appended
        before the axes are rearranged.
        """
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # Append a trailing axis so the array has an axis 2 to move.
            pixels = pixels[..., np.newaxis]
        # HWC to CHW
        return np.moveaxis(pixels, 2, 0)
def _pil_interp(method):
    """Translate an interpolation name into the matching PIL resampling constant.

    'bicubic', 'lanczos' and 'hamming' map to ``Image.BICUBIC``,
    ``Image.LANCZOS`` and ``Image.HAMMING``; every other value falls back to
    ``Image.BILINEAR``.
    """
    for known in ('bicubic', 'lanczos', 'hamming'):
        if method == known:
            # The PIL constant is the upper-cased method name; it is looked up
            # only for the name that actually matched.
            return getattr(Image, known.upper())
    # default bilinear, do we want to allow nearest?
    return Image.BILINEAR
def natural_key(string_):
    """Build a sort key that orders embedded numbers numerically.

    The lower-cased string is split into alternating non-digit and digit
    chunks; digit chunks become ``int`` so that e.g. 'f2' sorts before 'f10'.

    Args:
        string_: The text to build a key for.

    Returns:
        A list alternating ``str`` and ``int`` chunks.
    """
    # `\d` only matches decimal digits, which int() can always parse.
    # str.isdecimal() tests exactly that set, whereas str.isdigit() also
    # accepts characters such as '²' or '①' that make int() raise ValueError.
    return [int(s) if s.isdecimal() else s for s in re.split(r'(\d+)', string_.lower())]
def find_images_and_targets(folder, types=('.png', '.jpg', '.jpeg'), class_to_idx=None, leaf_name_only=True, sort=True):
    """Walk ``folder`` and pair every image file with an integer class index.

    Args:
        folder: Directory to walk (symlinked directories are followed).
        types: File extensions, compared against the lower-cased extension of
            each file, that count as images.
        class_to_idx: Optional mapping from label to index. When None, one is
            built from the labels found, ordered with ``natural_key``. Files
            whose label is missing from the mapping are dropped.
        leaf_name_only: If true, the label is the name of the directory holding
            the file; otherwise it is the path relative to ``folder`` with path
            separators replaced by '_'.
        sort: If true, the result is ordered by ``natural_key`` of the file path.

    Returns:
        ``(images_and_targets, class_to_idx)`` where the first item is a list
        of ``(file_path, class_index)`` tuples.
    """
    found = []  # (file path, label) pairs in walk order
    for root, _subdirs, files in os.walk(folder, topdown=False, followlinks=True):
        if root != folder:
            rel_path = os.path.relpath(root, folder)
        else:
            # Files directly inside `folder` get the empty label.
            rel_path = ''
        if leaf_name_only:
            label = os.path.basename(rel_path)
        else:
            label = rel_path.replace(os.path.sep, '_')
        for name in files:
            if os.path.splitext(name)[1].lower() in types:
                found.append((os.path.join(root, name), label))

    if class_to_idx is None:
        # building class index
        ordered_labels = sorted({label for _, label in found}, key=natural_key)
        class_to_idx = {label: position for position, label in enumerate(ordered_labels)}

    images_and_targets = [
        (path, class_to_idx[label])
        for path, label in found
        if label in class_to_idx
    ]
    if sort:
        images_and_targets.sort(key=lambda pair: natural_key(pair[0]))
    return images_and_targets, class_to_idx
# Default `mean` / `std` passed to transforms.Normalize by the transform
# builders in this file (presumably the usual per-channel RGB ImageNet
# statistics -- confirm against the source of the values).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Fallback crop fraction for transforms_imagenet_eval when no crop_pct is
# given: the image is resized to img_size / crop_pct before the center crop.
DEFAULT_CROP_PCT = 0.875
def transforms_noaug_train(
        img_size=224,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
):
    """Build the augmentation-free training pipeline: resize, then center crop.

    Args:
        img_size: Size handed to both ``Resize`` and ``CenterCrop``.
        interpolation: Interpolation name resolved through ``_pil_interp``;
            'random' is replaced by 'bilinear'.
        use_prefetcher: If true, the pipeline ends with ``ToNumpy`` and leaves
            tensor conversion and normalisation to the caller; otherwise it
            ends with ``ToTensor`` followed by ``Normalize``.
        mean: Mean passed to ``Normalize``.
        std: Standard deviation passed to ``Normalize``.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    # random interpolation not supported with no-aug
    resample = _pil_interp('bilinear' if interpolation == 'random' else interpolation)
    steps = [
        transforms.Resize(img_size, resample),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        steps.append(ToNumpy())
    else:
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)))
    return transforms.Compose(steps)
def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
):
    """Build the evaluation pipeline: resize to ``img_size / crop_pct``, then
    center crop to ``img_size``.

    Args:
        img_size: Target crop size, either a single number or a 2-element
            tuple/list.
        crop_pct: Fraction of the resized image kept by the crop; a falsy value
            selects ``DEFAULT_CROP_PCT``.
        interpolation: Interpolation name resolved through ``_pil_interp``.
        use_prefetcher: If true, the pipeline ends with ``ToNumpy`` and leaves
            tensor conversion and normalisation to the caller; otherwise it
            ends with ``ToTensor`` followed by ``Normalize``.
        mean: Mean passed to ``Normalize``.
        std: Standard deviation passed to ``Normalize``.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    if not crop_pct:
        crop_pct = DEFAULT_CROP_PCT

    if not isinstance(img_size, (tuple, list)):
        scale_size = int(math.floor(img_size / crop_pct))
    else:
        assert len(img_size) == 2
        is_square = img_size[-1] == img_size[-2]
        if is_square:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple(int(side / crop_pct) for side in img_size)

    steps = [transforms.Resize(scale_size, _pil_interp(interpolation))]
    steps.append(transforms.CenterCrop(img_size))
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        steps.append(ToNumpy())
    else:
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)))
    return transforms.Compose(steps)
class ImageNetDataset(data.Dataset):
    """Image classification dataset backed by a directory tree.

    Samples are discovered once, at construction time, with
    ``find_images_and_targets(root)``. Each item is the image at the stored
    path, converted to RGB and passed through ``transform``, together with its
    target.
    """

    def __init__(self, root: str, is_training: bool, transform: Optional[Callable] = None) -> None:
        """Scan ``root`` for samples and pick the transform.

        Args:
            root: Directory scanned for image files.
            is_training: Chooses the default transform when ``transform`` is
                None: ``transforms_noaug_train()`` if true, otherwise
                ``transforms_imagenet_eval()``.
            transform: Callable applied to every RGB image; overrides the
                default chosen by ``is_training``.
        """
        self.root = root
        if transform is None:
            transform = transforms_noaug_train() if is_training else transforms_imagenet_eval()
        self.transform = transform
        # Only the (path, target) list is kept; the class-to-index mapping
        # returned alongside it is discarded.
        self.data, _ = find_images_and_targets(root)

    def __len__(self) -> int:
        """Return the number of discovered samples."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Union[int, torch.Tensor]]:
        """Load, convert and transform the sample at ``index``.

        Returns:
            ``(img, target)``. ``img`` is whatever ``transform`` returns (a
            tensor with the default transforms). ``target`` is the stored
            target, or a ``torch.long`` tensor holding -1 when the stored
            target is None.
        """
        path, target = self.data[index]
        # Image.open keeps the file open until it is closed. convert() returns
        # an independent image, so the handle can be released as soon as the
        # RGB copy exists.
        with Image.open(path) as source:
            img = source.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target