| | |
| | import itertools |
| | import logging |
| | import numpy as np |
| | import operator |
| | import pickle |
| | from collections import OrderedDict, defaultdict |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| | import torch |
| | import torch.utils.data as torchdata |
| | from tabulate import tabulate |
| | from termcolor import colored |
| |
|
| | from detectron2.config import configurable |
| | from detectron2.structures import BoxMode |
| | from detectron2.utils.comm import get_world_size |
| | from detectron2.utils.env import seed_all_rng |
| | from detectron2.utils.file_io import PathManager |
| | from detectron2.utils.logger import _log_api_usage, log_first_n |
| |
|
| | from .catalog import DatasetCatalog, MetadataCatalog |
| | from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset |
| | from .dataset_mapper import DatasetMapper |
| | from .detection_utils import check_metadata_consistency |
| | from .samplers import ( |
| | InferenceSampler, |
| | RandomSubsetTrainingSampler, |
| | RepeatFactorTrainingSampler, |
| | TrainingSampler, |
| | ) |
| |
|
| | """ |
| | This file contains the default logic to build a dataloader for training or testing. |
| | """ |
| |
|
# Public API of this module; names not listed here are internal helpers.
__all__ = [
    "build_batch_data_loader",
    "build_detection_train_loader",
    "build_detection_test_loader",
    "get_detection_dataset_dicts",
    "load_proposals_into_dataset",
    "print_instances_class_histogram",
]
| |
|
| |
|
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Filter out images with none annotations or only crowd annotations
    (i.e., images without non-crowd annotations).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def valid(anns):
        # An image is usable iff at least one of its annotations is non-crowd.
        # Annotations without an "iscrowd" key are treated as non-crowd.
        return any(ann.get("iscrowd", 0) == 0 for ann in anns)

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts
| |
|
| |
|
def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Filter out images with too few number of keypoints.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """
    original_count = len(dataset_dicts)

    def count_visible_keypoints(record):
        # COCO-style keypoints are flat (x, y, v) triplets; every third value
        # is the visibility flag, and v > 0 marks a labeled keypoint.
        total = 0
        for ann in record["annotations"]:
            if "keypoints" in ann:
                total += int((np.array(ann["keypoints"][2::3]) > 0).sum())
        return total

    kept = [
        record
        for record in dataset_dicts
        if count_visible_keypoints(record) >= min_keypoints_per_image
    ]
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with fewer than {} keypoints.".format(
            original_count - len(kept), min_keypoints_per_image
        )
    )
    return kept
| |
|
| |
|
def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Load precomputed object proposals into the dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same format as dataset_dicts, but added proposal field.
    """
    logger = logging.getLogger(__name__)
    logger.info("Loading proposals from: {}".format(proposal_file))

    # NOTE(review): unpickling executes arbitrary code — proposal files must
    # come from a trusted source; do not point this at untrusted input.
    with PathManager.open(proposal_file, "rb") as f:
        proposals = pickle.load(f, encoding="latin1")

    # Rename the keys used by legacy (Detectron v1) proposal files to the
    # current convention.
    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
    for key in rename_keys:
        if key in proposals:
            proposals[rename_keys[key]] = proposals.pop(key)

    # Index proposals by image id, keeping only ids present in the dataset.
    # Ids are compared as strings since they may be int in one source and
    # str in the other.
    img_ids = {str(record["image_id"]) for record in dataset_dicts}
    id_to_index = {
        str(img_id): i for i, img_id in enumerate(proposals["ids"]) if str(img_id) in img_ids
    }

    # Default box format of precomputed proposals is absolute XYXY.
    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

    for record in dataset_dicts:
        # Raises KeyError if the proposal file has no entry for this image.
        i = id_to_index[str(record["image_id"])]

        boxes = proposals["boxes"][i]
        objectness_logits = proposals["objectness_logits"][i]
        # Sort the proposals in descending order of their objectness scores.
        inds = objectness_logits.argsort()[::-1]
        record["proposal_boxes"] = boxes[inds]
        record["proposal_objectness_logits"] = objectness_logits[inds]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts
| |
|
| |
|
def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table of per-class instance counts found in the dataset.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        # Crowd annotations are excluded from the statistics, matching the
        # training-time crowd filtering.
        classes = np.asarray(
            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
        )
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # Truncate long class names so the table stays narrow (useful for LVIS).
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flatten into [name0, count0, name1, count1, ...]; counts sit at odd indices.
    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    # Pad up to a multiple of N_COLS so the column-major reshape below is even.
    # BUGFIX: the previous `N_COLS - (len(data) % N_COLS)` appended N_COLS
    # trailing None cells (a spurious empty table row) whenever len(data) was
    # already an exact multiple of N_COLS; `(-len(data)) % N_COLS` pads 0 then.
    data.extend([None] * ((-len(data)) % N_COLS))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
| |
|
| |
|
def get_detection_dataset_dicts(
    names,
    filter_empty=True,
    min_keypoints=0,
    proposal_files=None,
    check_consistency=True,
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.
        check_consistency (bool): whether to check if datasets have consistent metadata.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names

    available_datasets = DatasetCatalog.keys()
    names_set = set(names)
    if not names_set.issubset(available_datasets):
        # Only warn here; DatasetCatalog.get() below is the authoritative
        # lookup and will fail for names that are truly unregistered.
        logger = logging.getLogger(__name__)
        logger.warning(
            "The following dataset names are not registered in the DatasetCatalog: "
            f"{names_set - available_datasets}. "
            f"Available datasets are {available_datasets}"
        )

    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]

    # Datasets registered as torch Datasets bypass all dict-based processing
    # below (filtering, proposals, histogram) and are returned directly.
    if isinstance(dataset_dicts[0], torchdata.Dataset):
        if len(dataset_dicts) > 1:
            # ConcatDataset only supports map-style datasets; iterable-style
            # datasets would need different handling.
            return torchdata.ConcatDataset(dataset_dicts)
        return dataset_dicts[0]

    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # Attach precomputed proposals, pairing each dataset with its file.
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    # Flatten the per-dataset lists into one combined list of dicts.
    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    # Instance-level filtering only applies when the first record carries
    # "annotations" (assumes all records are homogeneous in this respect).
    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if check_consistency and has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # datasets without thing_classes skip the check
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
| |
|
| |
|
def build_batch_data_loader(
    dataset,
    sampler,
    total_batch_size,
    *,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,
    drop_last: bool = True,
    single_gpu_batch_size=None,
    seed=None,
    **kwargs,
):
    """
    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
    1. support aspect ratio grouping options
    2. use no "batch collation", because this is common for detection training

    Args:
        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
            Must be provided iff. ``dataset`` is a map-style dataset.
        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
            :func:`build_detection_train_loader`.
        single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`.
            `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process.
            `total_batch_size` allows you to specify the total aggregate batch size across gpus.
            It is an error to supply a value for both.
        drop_last (bool): if ``True``, the dataloader will drop incomplete batches.
        seed (int or None): if given, seeds the torch Generator handed to the
            DataLoader for reproducible worker behavior.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
        GPU. Each element in the list comes from the dataset.
    """
    if single_gpu_batch_size:
        if total_batch_size:
            raise ValueError(
                """total_batch_size and single_gpu_batch_size are mutually incompatible.
                Please specify only one. """
            )
        batch_size = single_gpu_batch_size
    else:
        # Split the aggregate batch size evenly across all processes.
        world_size = get_world_size()
        assert (
            total_batch_size > 0 and total_batch_size % world_size == 0
        ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
            total_batch_size, world_size
        )
        batch_size = total_batch_size // world_size
    logger = logging.getLogger(__name__)
    logger.info("Making batched data loader with batch_size=%d", batch_size)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        # Wrap (map-style dataset, sampler) into an iterable dataset so both
        # styles share the same DataLoader construction below.
        dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size)

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    if aspect_ratio_grouping:
        assert drop_last, "Aspect ratio grouping will drop incomplete batches."
        # Load one sample at a time; batching happens in
        # AspectRatioGroupedDataset, which buckets by width/height ratio.
        data_loader = torchdata.DataLoader(
            dataset,
            num_workers=num_workers,
            collate_fn=operator.itemgetter(0),  # unwrap the 1-element "batch"
            worker_init_fn=worker_init_reset_seed,
            generator=generator,
            **kwargs
        )
        data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
        if collate_fn is None:
            return data_loader
        return MapDataset(data_loader, collate_fn)
    else:
        return torchdata.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            num_workers=num_workers,
            collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
            worker_init_fn=worker_init_reset_seed,
            generator=generator,
            **kwargs
        )
| |
|
| |
|
| | def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]: |
| | repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR |
| | assert all(len(tup) == 2 for tup in repeat_factors) |
| | name_to_weight = defaultdict(lambda: 1, dict(repeat_factors)) |
| | |
| | unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN) |
| | assert not unrecognized, f"unrecognized datasets: {unrecognized}" |
| | logger = logging.getLogger(__name__) |
| | logger.info(f"Found repeat factors: {list(name_to_weight.items())}") |
| |
|
| | |
| | return name_to_weight |
| |
|
| |
|
def _build_weighted_sampler(cfg, enable_category_balance=False):
    """
    Build a :class:`RepeatFactorTrainingSampler` whose per-image repeat factors
    come from the per-dataset weights in ``cfg.DATASETS.TRAIN_REPEAT_FACTOR``,
    optionally combined with per-category frequency factors.

    Args:
        cfg: config node providing DATASETS/DATALOADER/MODEL options.
        enable_category_balance (bool): if True, multiply the per-dataset
            repeat factors element-wise with category-frequency repeat factors.
    """
    dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
    # Load each dataset separately (same options as the train loader) so that
    # per-dataset lengths are known; OrderedDict keeps the cfg order.
    dataset_name_to_dicts = OrderedDict(
        {
            name: get_detection_dataset_dicts(
                [name],
                filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
                min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
                if cfg.MODEL.KEYPOINT_ON
                else 0,
                proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
                if cfg.MODEL.LOAD_PROPOSALS
                else None,
            )
            for name in cfg.DATASETS.TRAIN
        }
    )
    # Every image inherits its dataset's repeat factor.
    repeat_factors = [
        [dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname])
        for dsname in cfg.DATASETS.TRAIN
    ]

    # Flatten the per-dataset lists into one factor per image.
    repeat_factors = list(itertools.chain.from_iterable(repeat_factors))

    repeat_factors = torch.tensor(repeat_factors)
    logger = logging.getLogger(__name__)
    if enable_category_balance:
        """
        1. Calculate repeat factors using category frequency for each dataset and then merge them.
        2. Element wise dot producting the dataset frequency repeat factors with
            the category frequency repeat factors gives the final repeat factors.
        """
        category_repeat_factors = [
            RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            for dataset_dict in dataset_name_to_dicts.values()
        ]
        # Flatten the per-dataset category factors to one per image.
        category_repeat_factors = list(itertools.chain.from_iterable(category_repeat_factors))
        category_repeat_factors = torch.tensor(category_repeat_factors)
        repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
        # Normalize so the smallest combined repeat factor becomes 1.
        repeat_factors = repeat_factors / torch.min(repeat_factors)
        logger.info(
            "Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
                cfg.DATASETS.TRAIN_REPEAT_FACTOR
            )
        )
    else:
        logger.info(
            "Using WeightedTrainingSampler with repeat_factors={}".format(
                cfg.DATASETS.TRAIN_REPEAT_FACTOR
            )
        )

    sampler = RepeatFactorTrainingSampler(repeat_factors)
    return sampler
| |
|
| |
|
def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Translate a config node into the keyword arguments of
    :func:`build_detection_train_loader`. ``mapper``, ``dataset`` and
    ``sampler`` may be passed explicitly to override the cfg-derived defaults.
    """
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        # Default training-time mapper (is_train=True).
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        if isinstance(dataset, torchdata.IterableDataset):
            logger.info("Not using any sampler since the dataset is IterableDataset.")
            sampler = None
        else:
            # Dispatch on the sampler name configured in DATALOADER.SAMPLER_TRAIN.
            logger.info("Using training sampler {}".format(sampler_name))
            if sampler_name == "TrainingSampler":
                sampler = TrainingSampler(len(dataset))
            elif sampler_name == "RepeatFactorTrainingSampler":
                repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
                )
                sampler = RepeatFactorTrainingSampler(repeat_factors)
            elif sampler_name == "RandomSubsetTrainingSampler":
                sampler = RandomSubsetTrainingSampler(
                    len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
                )
            elif sampler_name == "WeightedTrainingSampler":
                sampler = _build_weighted_sampler(cfg)
            elif sampler_name == "WeightedCategoryTrainingSampler":
                sampler = _build_weighted_sampler(cfg, enable_category_balance=True)
            else:
                raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
| |
|
| |
|
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    collate_fn=None,
    **kwargs
):
    """
    Build a dataloader for object detection with some default features.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). It can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``.
            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
            Sampler must be None if ``dataset`` is iterable.
        total_batch_size (int): total batch size across all workers.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        collate_fn: a function that determines how to do batching, same as the argument of
            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
            data. No collation is OK for small batch size and simple data structures.
            If your batch size is large and each sample contains too many small tensors,
            it's more efficient to collate them in data loader.
        **kwargs: forwarded to :func:`build_batch_data_loader`.

    Returns:
        torch.utils.data.DataLoader:
            a dataloader. Each output from it is a ``list[mapped_element]`` of length
            ``total_batch_size // world_size`` (the per-GPU batch size), where
            ``mapped_element`` is produced by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            # Default: infinite shuffled index stream coordinated across workers.
            sampler = TrainingSampler(len(dataset))
        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **kwargs
    )
| |
|
| |
|
def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=[
            # Each test dataset's proposal file is looked up by its position
            # in cfg.DATASETS.TEST.
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        # Default inference-time mapper (is_train=False).
        mapper = DatasetMapper(cfg, False)
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        # Iterable datasets cannot take a sampler; map-style datasets get an
        # InferenceSampler that shards the exact sample set across workers.
        "sampler": InferenceSampler(len(dataset))
        if not isinstance(dataset, torchdata.IterableDataset)
        else None,
    }
| |
|
| |
|
@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset: Union[List[Any], torchdata.Dataset],
    *,
    mapper: Callable[[Dict[str, Any]], Any],
    sampler: Optional[torchdata.Sampler] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> torchdata.DataLoader:
    """
    Similar to `build_detection_train_loader`, with default batch size = 1,
    and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
    to produce the exact set of all samples.

    Args:
        dataset: a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). They can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper: a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler: a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
            which splits the dataset across all workers. Sampler must be None
            if `dataset` is iterable.
        batch_size: the batch size of the data loader to be created.
            Default to 1 image per worker since this is the standard when reporting
            inference time in papers.
        num_workers: number of parallel data loading workers
        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
            Defaults to do no collation and return a list of data.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetRegistry.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            # Default: enumerate every sample exactly once across all workers.
            sampler = InferenceSampler(len(dataset))
    # drop_last=False so every test sample is evaluated.
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
    )
| |
|
| |
|
def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.

    The input list of dataset samples is returned unchanged, so models receive
    a ``list[mapped_element]`` instead of a collated tensor batch.
    """
    return batch
| |
|
| |
|
def worker_init_reset_seed(worker_id):
    """
    DataLoader worker_init_fn: seed all RNGs in this worker process.

    Offsetting torch's initial seed by ``worker_id`` gives each worker a
    distinct, reproducible stream; the ``% 2**31`` keeps the value in 31-bit
    range (presumably required by seed_all_rng's underlying RNGs — confirm).
    """
    base_seed = torch.initial_seed() % 2**31
    seed_all_rng(base_seed + worker_id)
| |
|