|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Data generators for Scenic.""" |
|
|
|
|
|
import functools |
|
|
import importlib |
|
|
from typing import Callable, List |
|
|
|
|
|
from absl import logging |
|
|
from scenic.dataset_lib import dataset_utils |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps a dataset name to the module that is imported on demand when that name
# is requested from the registry but has not been registered yet. The module
# is expected to register the dataset under exactly this name when imported.
_IMPORT_TABLE = {
    'cifar10': 'scenic.dataset_lib.cifar10_dataset',
    'cityscapes': 'scenic.dataset_lib.cityscapes_dataset',
    'imagenet': 'scenic.dataset_lib.imagenet_dataset',
    'fashion_mnist': 'scenic.dataset_lib.fashion_mnist_dataset',
    'mnist': 'scenic.dataset_lib.mnist_dataset',
    'bair': 'scenic.dataset_lib.bair_dataset',
    'oxford_pets': 'scenic.dataset_lib.oxford_pets_dataset',
    'svhn': 'scenic.dataset_lib.svhn_dataset',
    'video_tfrecord_dataset':
        'scenic.projects.vivit.data.video_tfrecord_dataset',
    'av_asr_tfrecord_dataset':
        'scenic.projects.avatar.datasets.av_asr_tfrecord_dataset',
    'bit': 'scenic.dataset_lib.big_transfer.bit',
    'bert_wikibooks':
        'scenic.projects.baselines.bert.datasets.bert_wikibooks_dataset',
    'bert_glue': 'scenic.projects.baselines.bert.datasets.bert_glue_dataset',
    'coco_detr_detection':
        'scenic.projects.baselines.detr.input_pipeline_detection',
    'cityscapes_variants':
        'scenic.projects.robust_segvit.datasets.cityscapes_variants',
    'robust_segvit_segmentation':
        'scenic.projects.robust_segvit.datasets.segmentation_datasets',
    'robust_segvit_variants':
        'scenic.projects.robust_segvit.datasets.segmentation_variants',
    'flexio': 'scenic.dataset_lib.flexio.flexio',
}
|
|
|
|
|
|
|
|
class DatasetRegistry(object):
  """Static class for keeping track of available datasets.

  Builder functions are stored in a class-level dictionary keyed by dataset
  name. A name that is missing from the registry but listed in
  `_IMPORT_TABLE` triggers an import of the associated module on lookup; that
  module is expected to register the dataset while it is being imported.
  """

  # Maps dataset name -> builder function. Class-level, so it is shared by
  # every caller of the classmethods below.
  _REGISTRY = {}

  @classmethod
  def add(cls, name: str, builder_fn: Callable[..., dataset_utils.Dataset]):
    """Add a dataset to the registry, i.e. register a dataset.

    Args:
      name: Dataset name (must be unique).
      builder_fn: Function to be called to construct the datasets. Must accept
        dataset-specific arguments and return a dataset description.

    Raises:
      KeyError: If the provided name is not unique.
    """
    if name in cls._REGISTRY:
      raise KeyError(f'Dataset with name ({name}) already registered.')
    cls._REGISTRY[name] = builder_fn

  @classmethod
  def get(cls, name: str) -> Callable[..., dataset_utils.Dataset]:
    """Get a dataset from the registry by its name.

    If the name is not registered yet but appears in `_IMPORT_TABLE`, the
    corresponding module is imported first so that it can register itself.

    Args:
      name: Dataset name.

    Returns:
      Dataset builder function that accepts dataset-specific parameters and
      returns a dataset description.

    Raises:
      KeyError: If the dataset is not registered and not listed in
        `_IMPORT_TABLE`, or if the on-demand imported module did not register
        a dataset under `name`.
      ImportError: Propagated from `importlib.import_module` when the
        on-demand module cannot be imported.
    """
    # Fast path: the dataset has already been registered.
    if name in cls._REGISTRY:
      return cls._REGISTRY[name]

    if name not in _IMPORT_TABLE:
      raise KeyError(f'Unknown dataset ({name}). Did you import the dataset '
                     'module explicitly?')

    # Importing the module is done purely for its registration side effect.
    module = _IMPORT_TABLE[name]
    importlib.import_module(module)
    logging.info(
        'On-demand import of dataset (%s) from module (%s).', name, module)

    # The key in `_IMPORT_TABLE` and the name the module registers can drift
    # apart, so verify that the import actually registered `name`.
    if name not in cls._REGISTRY:
      raise KeyError(f'Imported module ({module}) did not register dataset '
                     f'({name}). Please check that dataset names match.')
    return cls._REGISTRY[name]

  @classmethod
  def list(cls) -> List[str]:
    """List the names of all currently registered datasets."""
    # Inside the method body `list` resolves to the builtin, not this method,
    # because class scope is not visible from function bodies.
    return list(cls._REGISTRY.keys())
|
|
|
|
|
|
|
|
def add_dataset(name: str, *args, **kwargs):
  """Decorator for shorthand dataset registration.

  Args:
    name: Name under which the decorated builder is registered.
    *args: Positional arguments pre-bound to the builder that gets registered.
    **kwargs: Keyword arguments pre-bound to the builder that gets registered.

  Returns:
    A decorator that registers a partially-applied version of the builder in
    `DatasetRegistry` and returns the original builder unchanged.
  """

  def _register(builder_fn: Callable[..., dataset_utils.Dataset]
                ) -> Callable[..., dataset_utils.Dataset]:
    # Only the registry entry carries the bound arguments; the decorated
    # function itself is handed back untouched.
    bound_builder = functools.partial(builder_fn, *args, **kwargs)
    DatasetRegistry.add(name, bound_builder)
    return builder_fn

  return _register
|
|
|
|
|
|
|
|
def get_dataset(dataset_name: str) -> Callable[..., dataset_utils.Dataset]:
  """Looks up the dataset builder registered under `dataset_name`.

  Thin wrapper around `DatasetRegistry.get`, kept so that existing code which
  predates the registry keeps working.

  Args:
    dataset_name: Dataset name.

  Returns:
    A dataset builder.
  """
  builder_fn = DatasetRegistry.get(dataset_name)
  return builder_fn
|
|
|