| |
| |
| |
| |
| |
| |
| |
| |
| import os |
| import random |
|
|
| import torch |
| from timm.data import create_dataset |
| from timm.data.transforms_factory import (transforms_imagenet_eval, |
| transforms_imagenet_train) |
| from torch.utils.data import DataLoader, Subset |
|
|
| from common.registries.dataset_registry import DATASET_WRAPPER_REGISTRY |
| from common.utils import LOGGER |
| from image_classification.pt.src.datasets import prepare_kwargs_for_dataloader |
| from image_classification.pt.src.datasets.augmentations.augs import ( |
| DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
| from image_classification.pt.src.datasets.dataset_utils import ( |
| PredictionDataset, create_loader) |
|
|
# Only the registered dataset wrapper is part of this module's public API;
# the create_*_dataset helpers below are called internally by get_custom.
__all__ = ['get_custom']
|
|
|
|
@DATASET_WRAPPER_REGISTRY.register(framework='torch', dataset_name='custom', use_case="image_classification")
def get_custom(cfg):
    """Build the dataloaders for a user-provided ("custom") image-classification dataset.

    Each split is optional: a loader is only created when the matching path is
    configured on ``cfg.dataset``; otherwise its slot in the result is ``None``.
    The quantization loader is always requested; it falls back to the training
    data or to ``None`` on its own.

    Args:
        cfg: Configuration object. ``cfg.dataset`` may define ``training_path``,
            ``validation_path``, ``test_path`` and ``prediction_path``; the
            remaining dataloader options are built from ``cfg`` by
            ``prepare_kwargs_for_dataloader``.

    Returns:
        dict: keys ``'train'``, ``'valid'``, ``'test'``, ``'quantization'`` and
        ``'predict'`` mapped to the corresponding loader or ``None``.
    """
    args = prepare_kwargs_for_dataloader(cfg)

    # The loader factories expect a torch.device rather than a device string.
    if isinstance(args["device"], str):
        args["device"] = torch.device(args["device"])

    train_loader = test_loader = val_loader = pred_loader = None
    args["training_path"] = getattr(cfg.dataset, "training_path", None)
    args["validation_path"] = getattr(cfg.dataset, "validation_path", None)
    # create_test_dataset / create_prediction_dataset read these keys with
    # args[...]; make sure they exist even when prepare_kwargs_for_dataloader
    # did not populate them. Values it did provide are left untouched.
    args.setdefault("test_path", getattr(cfg.dataset, "test_path", None))
    args.setdefault("prediction_path", getattr(cfg.dataset, "prediction_path", None))

    if args["training_path"]:
        LOGGER.info(f"Loading training data from: {cfg.dataset.training_path}")
        train_loader = create_training_dataset(args)
    else:
        LOGGER.info("No path available for training data")

    if args["validation_path"]:
        LOGGER.info(f"Loading validation data from: {cfg.dataset.validation_path}")
        val_loader = create_validation_dataset(args)
    else:
        LOGGER.info("No path available for validation data")

    if getattr(cfg.dataset, "test_path", None):
        LOGGER.info(f"Loading test data from: {cfg.dataset.test_path}")
        test_loader = create_test_dataset(args)
    else:
        LOGGER.info("No path available for test data")

    # Returns None by itself when neither a quantization nor a training path exists.
    quant_loader = create_quantization_dataset(args)

    if getattr(cfg.dataset, "prediction_path", None):
        LOGGER.info(f"Loading prediction data from {cfg.dataset.prediction_path}")
        pred_loader = create_prediction_dataset(args)
    else:
        LOGGER.info("No path available for prediction data")

    return {'train': train_loader, 'valid': val_loader, 'test': test_loader, 'quantization': quant_loader, 'predict': pred_loader}
| |
def create_training_dataset(args):
    """Create the training loader for the image folder at ``args["training_path"]``.

    By default, images go through timm's ImageNet training transforms configured
    from ``args``. ``args["train_transforms"]`` replaces them when that key is
    present.

    Args:
        args (dict): dataloader options (as assembled by ``get_custom``),
            including ``training_path``.

    Returns:
        The loader produced by ``create_loader`` for the training split.
    """
    root_dir = args["training_path"]

    # Random-erasing split count: only meaningful when re_split is enabled,
    # in which case an unset num_aug_splits falls back to 2.
    erase_splits = (args["num_aug_splits"] or 2) if args["re_split"] else 0

    side = args["img_size"]
    if isinstance(side, (tuple, list)):
        # The transform factory takes one integer size; use the last dimension.
        side = side[-1]

    norm_mean = args["mean"] or IMAGENET_DEFAULT_MEAN
    norm_std = args["std"] or IMAGENET_DEFAULT_STD
    fallback_tf = transforms_imagenet_train(
        side,
        mean=norm_mean,
        std=norm_std,
        scale=args["scale"],
        ratio=args["ratio"],
        hflip=args["hflip"],
        vflip=args["vflip"],
        color_jitter=args["color_jitter"],
        auto_augment=args["auto_augment"],
        interpolation=args["train_interpolation"],
        re_prob=args["re_prob"],
        re_mode=args["re_mode"],
        re_count=args["re_count"],
        re_num_splits=erase_splits,
        use_prefetcher=args["use_prefetcher"],
    )

    train_set = create_dataset(
        'imagenet',
        root=root_dir,
        search_split=False,
        is_training=True,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
        seed=args["seed"],
        repeats=args["repeats"],
    )
    train_set.transform = args.get("train_transforms", fallback_tf)
    train_set.classes = range(args["num_classes"])

    loader_opts = dict(
        input_size=args["img_size"],
        batch_size=args["batch_size"],
        is_training=True,
        use_prefetcher=args["use_prefetcher"],
        no_aug=args["no_aug"],
        re_prob=args["re_prob"],
        re_mode=args["re_mode"],
        re_count=args["re_count"],
        num_aug_repeats=args["num_aug_repeats"],
        re_num_splits=erase_splits,
        mean=norm_mean,
        std=norm_std,
        num_workers=args["num_workers"],
        distributed=args["distributed"],
        collate_fn=args["collate_fn"],
        pin_memory=args["pin_memory"],
        device=args["device"],
        use_multi_epochs_loader=args["use_multi_epochs_loader"],
        worker_seeding=args["worker_seeding"],
    )
    return create_loader(train_set, **loader_opts)
|
|
|
|
def create_validation_dataset(args):
    """Create the evaluation loader for the image folder at ``args["validation_path"]``.

    By default, images go through timm's ImageNet evaluation transforms.
    ``args["val_transforms"]`` replaces them when present. The batch size is
    ``args["val_batch_size"]`` when that key exists, otherwise
    ``args["batch_size"]``.

    Args:
        args (dict): dataloader options (as assembled by ``get_custom``),
            including ``validation_path``.

    Returns:
        The loader produced by ``create_loader`` for the validation split.
    """
    val_root = args["validation_path"]

    crop_size = args["img_size"]
    if isinstance(crop_size, (tuple, list)):
        crop_size = crop_size[-1]  # the transform factory expects a single int

    norm_mean = args["mean"] or IMAGENET_DEFAULT_MEAN
    norm_std = args["std"] or IMAGENET_DEFAULT_STD
    eval_tf = transforms_imagenet_eval(
        crop_size,
        mean=norm_mean,
        std=norm_std,
        crop_pct=args.get("crop_pct") or DEFAULT_CROP_PCT,
        interpolation=args["test_interpolation"],
        use_prefetcher=args["use_prefetcher"],
    )

    val_set = create_dataset(
        'imagenet',
        root=val_root,
        search_split=False,
        is_training=False,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
    )
    val_set.transform = args.get("val_transforms", eval_tf)

    return create_loader(
        val_set,
        input_size=args["img_size"],
        batch_size=args.get("val_batch_size", args["batch_size"]),
        is_training=False,
        use_prefetcher=args["use_prefetcher"],
        mean=norm_mean,
        std=norm_std,
        num_workers=args["num_workers"],
        distributed=args["distributed"],
        pin_memory=args["pin_memory"],
        device=args["device"],
    )
|
|
def create_test_dataset(args):
    """Create the evaluation loader for the image folder at ``args["test_path"]``.

    By default, images go through timm's ImageNet evaluation transforms.
    ``args["test_transforms"]`` replaces them when present. The batch size is
    ``args["test_batch_size"]`` when that key exists, otherwise
    ``args["batch_size"]``.

    Args:
        args (dict): dataloader options; must contain ``test_path``.

    Returns:
        The loader produced by ``create_loader`` for the test split.
    """
    size_cfg = args["img_size"]
    side = size_cfg[-1] if isinstance(size_cfg, (tuple, list)) else size_cfg

    mean = args["mean"] or IMAGENET_DEFAULT_MEAN
    std = args["std"] or IMAGENET_DEFAULT_STD
    default_tf = transforms_imagenet_eval(
        side,
        mean=mean,
        std=std,
        crop_pct=args.get("crop_pct") or DEFAULT_CROP_PCT,
        interpolation=args["test_interpolation"],
        use_prefetcher=args["use_prefetcher"],
    )

    test_set = create_dataset(
        'imagenet',
        root=args["test_path"],
        search_split=False,
        is_training=False,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
    )
    test_set.transform = args.get("test_transforms", default_tf)

    loader_opts = {
        "input_size": args["img_size"],
        "batch_size": args.get("test_batch_size", args["batch_size"]),
        "is_training": False,
        "use_prefetcher": args["use_prefetcher"],
        "mean": mean,
        "std": std,
        "num_workers": args["num_workers"],
        "distributed": args["distributed"],
        "pin_memory": args["pin_memory"],
        "device": args["device"],
    }
    return create_loader(test_set, **loader_opts)
|
|
def create_quantization_dataset(args):
    """Create a loader of calibration images for post-training quantization.

    Images are read from ``args["quantization_path"]`` when set, otherwise from
    ``args["training_path"]``. A random subset of
    ``int(len(dataset) * quantization_split)`` images is drawn, capped at the
    dataset size. ``quantization_split`` comes from ``args`` and defaults to 1.0,
    i.e. the whole dataset. The subset is served one image per batch.

    Args:
        args (dict): dataloader options (as assembled by ``get_custom``).

    Returns:
        torch.utils.data.DataLoader or None: ``None`` when neither a quantization
        nor a training path is configured.

    Raises:
        ValueError: if ``quantization_split`` is negative.
    """
    # Resolve the data source first so no transforms are built when there is no data.
    if args.get("quantization_path"):
        data_path = args["quantization_path"]
        LOGGER.info(f"Loading quantization data from {data_path}")
    elif args.get("training_path"):
        data_path = args["training_path"]
        LOGGER.info(f"Loading quantization data from training data at: {data_path}")
    else:
        LOGGER.info("No path available for quantization data")
        return None

    re_num_splits = 0
    if args["re_split"]:
        re_num_splits = args["num_aug_splits"] or 2

    img_size = args["img_size"]
    if isinstance(img_size, (tuple, list)):
        img_size = img_size[-1]

    # This is a plain torch DataLoader that is never wrapped by timm's prefetcher.
    # With use_prefetcher=True, timm's transform pipeline stops at an
    # un-normalized uint8 array and leaves conversion and normalization to the
    # prefetcher. So the transforms must do the tensor conversion and
    # normalization themselves here.
    default_train_transforms = transforms_imagenet_train(
        img_size,
        mean=args["mean"] or IMAGENET_DEFAULT_MEAN,
        std=args["std"] or IMAGENET_DEFAULT_STD,
        scale=args["scale"],
        ratio=args["ratio"],
        hflip=args["hflip"],
        vflip=args["vflip"],
        color_jitter=args["color_jitter"],
        auto_augment=args["auto_augment"],
        interpolation=args["train_interpolation"],
        re_prob=args["re_prob"],
        re_mode=args["re_mode"],
        re_count=args["re_count"],
        re_num_splits=re_num_splits,
        use_prefetcher=False,
    )

    dataset_train = create_dataset(
        'imagenet',
        root=data_path,
        search_split=False,
        is_training=True,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
        seed=args["seed"],
        repeats=args["repeats"],
    )
    dataset_train.transform = args.get("train_transforms", default_train_transforms)
    dataset_train.classes = range(args["num_classes"])

    # A key that is present but None means "not configured"; use the whole dataset.
    quantization_split = args.get("quantization_split")
    if quantization_split is None:
        quantization_split = 1.0
    if quantization_split < 0:
        raise ValueError(
            f"quantization_split must be non-negative, got {quantization_split}"
        )
    if quantization_split == 1.0:
        LOGGER.info("100 percent data is being used for quantization")
    else:
        LOGGER.info(f"{quantization_split * 100:g} percent data is being used for quantization")

    total = len(dataset_train)
    # Splits above 1.0 are capped at the full dataset.
    num_quant_samples = min(int(total * quantization_split), total)
    quant_indices = random.sample(range(total), num_quant_samples)
    quant_subset = Subset(dataset_train, quant_indices)

    return DataLoader(
        quant_subset,
        batch_size=1,
        shuffle=False,
        num_workers=args["num_workers"],
        pin_memory=args["pin_memory"],
    )
|
|
def create_prediction_dataset(args):
    """Create a loader over the images at ``args["prediction_path"]``.

    Images are wrapped in ``PredictionDataset`` with timm's ImageNet evaluation
    transforms and served one image per batch, in order (no shuffling).

    Args:
        args (dict): dataloader options; must contain ``prediction_path``.

    Returns:
        torch.utils.data.DataLoader: the prediction loader.
    """
    img_size = args["img_size"]
    if isinstance(img_size, (tuple, list)):
        img_size = img_size[-1]

    # This is a plain torch DataLoader that is never wrapped by timm's prefetcher.
    # With use_prefetcher=True, timm's eval pipeline would stop at an
    # un-normalized uint8 array and expect the prefetcher to finish the job.
    # So the transforms must produce a normalized tensor themselves.
    pred_transforms = transforms_imagenet_eval(
        img_size,
        mean=args["mean"] or IMAGENET_DEFAULT_MEAN,
        std=args["std"] or IMAGENET_DEFAULT_STD,
        crop_pct=args.get("crop_pct") or DEFAULT_CROP_PCT,
        interpolation=args["test_interpolation"],
        use_prefetcher=False,
    )
    dataset_pred = PredictionDataset(args["prediction_path"], pred_transforms)
    return DataLoader(
        dataset_pred,
        batch_size=1,
        shuffle=False,
        num_workers=args["num_workers"],
        pin_memory=args["pin_memory"],
    )