Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Example classification decoder and parser. | |
| This file defines the Decoder and Parser used to load data. The example shows | |
| how to load standard tf.Example data, but non-standard tf.Example or other | |
| data formats can be supported by implementing a suitable decoder and parser. | |
| """ | |
| from typing import Mapping, List, Tuple | |
| # Import libraries | |
| import tensorflow as tf, tf_keras | |
| from official.vision.dataloaders import decoder | |
| from official.vision.dataloaders import parser | |
| from official.vision.ops import preprocess_ops | |
class Decoder(decoder.Decoder):
  """Decodes serialized tf.Example protos for the classification task."""

  def __init__(self):
    """Sets up the feature spec used to parse each tf.Example.

    The spec maps a field name in the input tf.Example to the description of
    the value stored under it. Two fields are declared here, the encoded image
    bytes and the integer class label, but any number of fields may be added.
    """
    # Scalar string holding the encoded image; empty string when absent.
    encoded_image_spec = tf.io.FixedLenFeature(
        (), tf.string, default_value='')
    # Scalar int64 class id; -1 when absent.
    class_label_spec = tf.io.FixedLenFeature((), tf.int64, default_value=-1)

    self._keys_to_features = {
        'image/encoded': encoded_image_spec,
        'image/class/label': class_label_spec,
    }

  def decode(self,
             serialized_example: tf.train.Example) -> Mapping[str, tf.Tensor]:
    """Parses one serialized tf.Example into a dictionary of tensors.

    The returned dictionary is what `_parse_train_data` and `_parse_eval_data`
    in Parser receive as input.

    Args:
      serialized_example: A serialized tf.Example.

    Returns:
      A dictionary mapping each field key name to its decoded tensor.
    """
    decoded_tensors = tf.io.parse_single_example(
        serialized=serialized_example, features=self._keys_to_features)
    return decoded_tensors
class Parser(parser.Parser):
  """Parser to parse an image and its annotations.

  To define a custom Parser, override `_parse_train_data` and
  `_parse_eval_data`, where decoded tensors are parsed with optional
  pre-processing steps. The output from the two functions can be any structure
  like tuple, list or dictionary.
  """

  def __init__(self, output_size: List[int], num_classes: int):
    """Initializes parameters for parsing annotations in the dataset.

    This example only takes two arguments but one can freely add as many
    arguments as needed. For example, pre-processing and augmentations usually
    happen in Parser, and related parameters can be passed in by this
    constructor.

    Args:
      output_size: `Tensor`, `list` or `tuple` for [height, width] of output
        image.
      num_classes: `int`, number of classes.
    """
    self._output_size = output_size
    self._num_classes = num_classes
    self._dtype = tf.float32

  def _parse_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Turns decoded tensors into an (image, label) pair.

    Shared by `_parse_train_data` and `_parse_eval_data`, which apply the same
    processing in this example.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder. Must contain 'image/encoded' and 'image/class/label'.

    Returns:
      A tuple of (image, label) tensors. The image is resized to the configured
      output size with 3 channels, normalized, and converted to float32; the
      label is cast to int32.
    """
    label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
    image_bytes = decoded_tensors['image/encoded']
    image = tf.io.decode_jpeg(image_bytes, channels=3)
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)

    # Pin the static shape after resizing. `output_size + [3]` only works for
    # a Python list: a tuple raises TypeError and a Tensor is added
    # elementwise instead of concatenated. Build the shape explicitly, and
    # when the size is not a Python sequence only the channel dimension is
    # known statically.
    if isinstance(self._output_size, (list, tuple)):
      static_shape = [*self._output_size, 3]
    else:
      static_shape = [None, None, 3]
    image = tf.ensure_shape(image, static_shape)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
    image = tf.image.convert_image_dtype(image, self._dtype)
    return image, label

  def _parse_train_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)

  def _parse_eval_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for evaluation.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)