Spaces:
Build error
Build error
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import dataclasses | |
| import enum | |
| import subprocess | |
| from typing import Callable, Optional | |
class Framework(enum.Enum):
    """Machine-learning framework a dataset is intended for."""

    TENSORFLOW = "tensorflow"
    PYTORCH = "pytorch"


@dataclasses.dataclass
class DatasetSpec:
    """Pair a dataset factory with the framework it targets.

    Attributes:
        framework: Framework the created dataset is meant for, or ``None``
            when it is not tied to a specific framework.
        create_dataset_fn: Callable that builds the dataset. The factories in
            this module accept an optional ``batch_size`` keyword argument.
    """

    # The decorator generates the keyword-accepting ``__init__`` that
    # ``DatasetSpec(framework=..., create_dataset_fn=...)`` relies on; without it
    # the class takes no constructor arguments at all.
    framework: Optional[Framework]
    create_dataset_fn: Callable
# COCO category names keyed by category id, in ascending id order.
# The id space is sparse: 12, 26, 29, 30, 45, 66, 68, 69, 71 and 83 have no
# entry, and id 0 ("background") is deliberately left unmapped.
# NOTE(review): these are the original, non-contiguous COCO ids. If a consumer
# looks labels up with contiguous 0-based class indices, it needs a remapping
# first — confirm against the callers.
COCO_LABELS = {
    1: "person", 2: "bicycle", 3: "car", 4: "motorcycle", 5: "airplane",
    6: "bus", 7: "train", 8: "truck", 9: "boat", 10: "traffic light",
    11: "fire hydrant", 13: "stop sign", 14: "parking meter", 15: "bench",
    16: "bird", 17: "cat", 18: "dog", 19: "horse", 20: "sheep",
    21: "cow", 22: "elephant", 23: "bear", 24: "zebra", 25: "giraffe",
    27: "backpack", 28: "umbrella", 31: "handbag", 32: "tie", 33: "suitcase",
    34: "frisbee", 35: "skis", 36: "snowboard", 37: "sports ball", 38: "kite",
    39: "baseball bat", 40: "baseball glove", 41: "skateboard", 42: "surfboard",
    43: "tennis racket", 44: "bottle", 46: "wine glass", 47: "cup", 48: "fork",
    49: "knife", 50: "spoon", 51: "bowl", 52: "banana", 53: "apple",
    54: "sandwich", 55: "orange", 56: "broccoli", 57: "carrot", 58: "hot dog",
    59: "pizza", 60: "donut", 61: "cake", 62: "chair", 63: "couch",
    64: "potted plant", 65: "bed", 67: "dining table", 70: "toilet", 72: "tv",
    73: "laptop", 74: "mouse", 75: "remote", 76: "keyboard", 77: "cell phone",
    78: "microwave", 79: "oven", 80: "toaster", 81: "sink", 82: "refrigerator",
    84: "book", 85: "clock", 86: "vase", 87: "scissors", 88: "teddy bear",
    89: "hair drier", 90: "toothbrush",
}
def _create_tfds_coco2017_validation(batch_size: Optional[int] = None) -> tuple:
    """Load the ``coco/2017`` validation split through TensorFlow Datasets.

    ``tensorflow-datasets`` is installed/upgraded with pip on every call, then
    imported, so the package does not have to be present beforehand.

    Args:
        batch_size: Forwarded unchanged to ``tfds.load``.

    Returns:
        Whatever ``tfds.load`` returns for these arguments; with
        ``with_info=True`` TFDS documents this as a ``(dataset, info)`` tuple.

    Raises:
        subprocess.CalledProcessError: If the pip installation exits non-zero.
    """
    import importlib
    import sys

    # Run pip through the current interpreter so the package is installed into
    # the environment this process imports from; a bare "pip" on PATH may
    # belong to a different Python installation.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--upgrade", "tensorflow-datasets"],
        check=True,
    )
    # The package may have appeared after the interpreter started; refresh the
    # import system's finder caches so the import below can see it.
    importlib.invalidate_caches()
    import tensorflow_datasets as tfds  # pytype: disable=import-error

    # NOTE(review): TFDS raises if ``as_supervised=True`` is requested for a
    # dataset without supervised keys; confirm coco/2017 defines them.
    return tfds.load("coco/2017", split="validation", as_supervised=True, with_info=True, batch_size=batch_size)


TFDS_COCO2017_VALIDATION_DATASET = DatasetSpec(
    framework=Framework.TENSORFLOW,
    create_dataset_fn=_create_tfds_coco2017_validation,
)
def _create_tfds_tf_flowers(batch_size: Optional[int] = None) -> tuple:
    """Load the ``tf_flowers`` dataset through TensorFlow Datasets.

    ``tensorflow-datasets`` is installed/upgraded with pip on every call, then
    imported, so the package does not have to be present beforehand. No
    ``split`` is passed to ``tfds.load``.

    Args:
        batch_size: Forwarded unchanged to ``tfds.load``.

    Returns:
        Whatever ``tfds.load`` returns for these arguments; with
        ``with_info=True`` TFDS documents this as a ``(dataset, info)`` tuple.

    Raises:
        subprocess.CalledProcessError: If the pip installation exits non-zero.
    """
    import importlib
    import sys

    # Run pip through the current interpreter so the package is installed into
    # the environment this process imports from; a bare "pip" on PATH may
    # belong to a different Python installation.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--upgrade", "tensorflow-datasets"],
        check=True,
    )
    # The package may have appeared after the interpreter started; refresh the
    # import system's finder caches so the import below can see it.
    importlib.invalidate_caches()
    import tensorflow_datasets as tfds  # pytype: disable=import-error

    return tfds.load("tf_flowers", as_supervised=True, with_info=True, batch_size=batch_size)


TFDS_TF_FLOWERS_DATASET = DatasetSpec(
    framework=Framework.TENSORFLOW,
    create_dataset_fn=_create_tfds_tf_flowers,
)
# Catalogue of the TFDS-backed dataset specs defined in this module.
DATASETS_CATALOGUE = [
    TFDS_COCO2017_VALIDATION_DATASET,
    TFDS_TF_FLOWERS_DATASET,
]