Spaces:
Sleeping
Sleeping
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Classification decoder and parser.""" | |
| from typing import Any, Dict, List, Optional, Tuple | |
| # Import libraries | |
| import tensorflow as tf, tf_keras | |
| from official.vision.configs import common | |
| from official.vision.dataloaders import decoder | |
| from official.vision.dataloaders import parser | |
| from official.vision.ops import augment | |
| from official.vision.ops import preprocess_ops | |
# Default tf.Example feature key holding the encoded image bytes.
DEFAULT_IMAGE_FIELD_KEY = 'image/encoded'
# Default tf.Example feature key holding the integer class label(s).
DEFAULT_LABEL_FIELD_KEY = 'image/class/label'
class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self,
               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
               is_multilabel: bool = False,
               keys_to_features: Optional[Dict[str, Any]] = None):
    if not keys_to_features:
      # Build the default schema: an encoded image plus either a
      # variable-length (multi-label) or scalar (single-label) int64 label.
      if is_multilabel:
        label_feature = tf.io.VarLenFeature(dtype=tf.int64)
      else:
        label_feature = tf.io.FixedLenFeature((), tf.int64, default_value=-1)
      keys_to_features = {
          image_field_key:
              tf.io.FixedLenFeature((), tf.string, default_value=''),
          label_field_key:
              label_feature,
      }
    self._keys_to_features = keys_to_features

  def decode(self, serialized_example):
    """Parses a single serialized tf.Example into a feature dictionary."""
    return tf.io.parse_single_example(serialized_example,
                                      self._keys_to_features)
class Parser(parser.Parser):
  """Parser to parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               output_size: List[int],
               num_classes: int,
               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
               decode_jpeg_only: bool = True,
               aug_rand_hflip: bool = True,
               aug_crop: Optional[bool] = True,
               aug_type: Optional[common.Augmentation] = None,
               color_jitter: float = 0.,
               random_erasing: Optional[common.RandomErasing] = None,
               is_multilabel: bool = False,
               dtype: str = 'float32',
               crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0),
               center_crop_fraction: Optional[
                   float] = preprocess_ops.CENTER_CROP_FRACTION,
               tf_resize_method: str = 'bilinear',
               three_augment: bool = False):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `list` for [height, width] of output image. The
        output_size should be divided by the largest feature stride
        2^max_level.
      num_classes: `int`, number of classes; used as the one-hot depth when
        `is_multilabel` is True.
      image_field_key: `str`, the key name to encoded image or decoded image
        matrix in tf.Example.
      label_field_key: `str`, the key name to label in tf.Example.
      decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
        faster than decoding other types. Default is True.
      aug_rand_hflip: `bool`, if True, augment training with random horizontal
        flip.
      aug_crop: `bool`, if True, perform random cropping during training and
        center crop during validation.
      aug_type: An optional Augmentation object to choose from AutoAugment and
        RandAugment.
      color_jitter: Magnitude of color jitter. If > 0, the value is used to
        generate random scale factor for brightness, contrast and saturation.
        See `preprocess_ops.color_jitter` for more details.
      random_erasing: if not None, augment input image by random erasing. See
        `augment.RandomErasing` for more details.
      is_multilabel: A `bool`, whether or not each example has multiple labels.
      dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
        or 'bfloat16'.
      crop_area_range: An optional `tuple` of (min_area, max_area) for image
        random crop function to constraint crop operation. The cropped areas
        of the image must contain a fraction of the input image within this
        range. The default area range is (0.08, 1.0).
        https://arxiv.org/abs/2204.07118.
      center_crop_fraction: `float`, fraction of the source image kept by the
        center crop during evaluation.
      tf_resize_method: A `str`, interpolation method for resizing image.
      three_augment: A bool, whether to apply three augmentations.

    Raises:
      ValueError: If `dtype` or `aug_type.type` is not supported.
    """
    self._output_size = output_size
    self._aug_rand_hflip = aug_rand_hflip
    self._aug_crop = aug_crop
    self._num_classes = num_classes
    self._image_field_key = image_field_key
    # Resolve the dtype name once via a lookup table instead of branching.
    dtype_by_name = {
        'float32': tf.float32,
        'float16': tf.float16,
        'bfloat16': tf.bfloat16,
    }
    if dtype not in dtype_by_name:
      raise ValueError('dtype {!r} is not supported!'.format(dtype))
    self._dtype = dtype_by_name[dtype]
    self._augmenter = self._build_augmenter(aug_type)
    self._label_field_key = label_field_key
    self._color_jitter = color_jitter
    if random_erasing:
      self._random_erasing = augment.RandomErasing(
          probability=random_erasing.probability,
          min_area=random_erasing.min_area,
          max_area=random_erasing.max_area,
          min_aspect=random_erasing.min_aspect,
          max_aspect=random_erasing.max_aspect,
          min_count=random_erasing.min_count,
          max_count=random_erasing.max_count,
          trials=random_erasing.trials)
    else:
      self._random_erasing = None
    self._is_multilabel = is_multilabel
    self._decode_jpeg_only = decode_jpeg_only
    self._crop_area_range = crop_area_range
    self._center_crop_fraction = center_crop_fraction
    self._tf_resize_method = tf_resize_method
    self._three_augment = three_augment

  def _build_augmenter(self, aug_type: Optional[common.Augmentation]):
    """Builds the AutoAugment/RandAugment augmenter; None when disabled."""
    if not aug_type:
      return None
    if aug_type.type == 'autoaug':
      return augment.AutoAugment(
          augmentation_name=aug_type.autoaug.augmentation_name,
          cutout_const=aug_type.autoaug.cutout_const,
          translate_const=aug_type.autoaug.translate_const)
    elif aug_type.type == 'randaug':
      return augment.RandAugment(
          num_layers=aug_type.randaug.num_layers,
          magnitude=aug_type.randaug.magnitude,
          cutout_const=aug_type.randaug.cutout_const,
          translate_const=aug_type.randaug.translate_const,
          prob_to_apply=aug_type.randaug.prob_to_apply,
          exclude_ops=aug_type.randaug.exclude_ops)
    raise ValueError('Augmentation policy {} not supported.'.format(
        aug_type.type))

  def _process_label(self, decoded_tensors) -> tf.Tensor:
    """Casts the decoded label; converts to a k-hot vector for multilabel."""
    label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
    if self._is_multilabel:
      # VarLenFeature labels come back sparse; densify before one-hot.
      if isinstance(label, tf.sparse.SparseTensor):
        label = tf.sparse.to_dense(label)
      # Sum of one-hots yields a k-hot multi-label target vector.
      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
    return label

  def _parse_train_data(self, decoded_tensors):
    """Parses data for training."""
    image = self._parse_train_image(decoded_tensors)
    label = self._process_label(decoded_tensors)
    return image, label

  def _parse_eval_data(self, decoded_tensors):
    """Parses data for evaluation."""
    image = self._parse_eval_image(decoded_tensors)
    label = self._process_label(decoded_tensors)
    return image, label

  def _parse_train_image(self, decoded_tensors) -> tf.Tensor:
    """Parses image data for training."""
    image_bytes = decoded_tensors[self._image_field_key]
    # The image field may hold either an encoded string or an already-decoded
    # matrix; only strings need decoding.
    require_decoding = (
        not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
    )
    if (
        require_decoding
        and self._decode_jpeg_only
        and self._aug_crop
    ):
      # Fast path: crop directly from the JPEG bytes without a full decode.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)
      # Crops image.
      cropped_image = preprocess_ops.random_crop_image_v2(
          image_bytes, image_shape, area_range=self._crop_area_range)
      # If the random crop failed (returned the full image), fall back to a
      # center crop.
      image = tf.cond(
          tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
          lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
          lambda: cropped_image)
    else:
      if require_decoding:
        # Decodes image.
        image = tf.io.decode_image(image_bytes, channels=3)
        image.set_shape([None, None, 3])
      else:
        # Already decoded image matrix.
        image = image_bytes
      # Crops image.
      if self._aug_crop:
        cropped_image = preprocess_ops.random_crop_image(
            image, area_range=self._crop_area_range)
        # Same fallback as above: center crop when random crop was a no-op.
        image = tf.cond(
            tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
            lambda: preprocess_ops.center_crop_image(image),
            lambda: cropped_image)
    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)
    # Color jitter.
    if self._color_jitter > 0:
      image = preprocess_ops.color_jitter(image, self._color_jitter,
                                          self._color_jitter,
                                          self._color_jitter)
    # Resizes image.
    image = tf.image.resize(
        image, self._output_size, method=self._tf_resize_method)
    image.set_shape([self._output_size[0], self._output_size[1], 3])
    # Apply autoaug or randaug.
    if self._augmenter is not None:
      image = self._augmenter.distort(image)
    # Three augmentation (DeiT III; https://arxiv.org/abs/2204.07118).
    if self._three_augment:
      image = augment.AutoAugment(
          augmentation_name='deit3_three_augment',
          translate_const=20,
      ).distort(image)
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
    # Random erasing after the image has been normalized.
    if self._random_erasing is not None:
      image = self._random_erasing.distort(image)
    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
    return image

  def _parse_eval_image(self, decoded_tensors) -> tf.Tensor:
    """Parses image data for evaluation."""
    image_bytes = decoded_tensors[self._image_field_key]
    # The image field may hold either an encoded string or an already-decoded
    # matrix; only strings need decoding.
    require_decoding = (
        not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
    )
    if (
        require_decoding
        and self._decode_jpeg_only
        and self._aug_crop
    ):
      # Fast path: center-crop directly from the JPEG bytes.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)
      # Center crops.
      image = preprocess_ops.center_crop_image_v2(
          image_bytes, image_shape, self._center_crop_fraction)
    else:
      if require_decoding:
        # Decodes image.
        image = tf.io.decode_image(image_bytes, channels=3)
        image.set_shape([None, None, 3])
      else:
        # Already decoded image matrix.
        image = image_bytes
      # Center crops.
      if self._aug_crop:
        image = preprocess_ops.center_crop_image(
            image, self._center_crop_fraction)
    image = tf.image.resize(
        image, self._output_size, method=self._tf_resize_method)
    image.set_shape([self._output_size[0], self._output_size[1], 3])
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
    return image

  def parse_train_image(self, decoded_tensors: Dict[str,
                                                    tf.Tensor]) -> tf.Tensor:
    """Public interface for parsing image data for training."""
    return self._parse_train_image(decoded_tensors)

  @classmethod
  def inference_fn(cls,
                   image: tf.Tensor,
                   input_image_size: List[int],
                   num_channels: int = 3) -> tf.Tensor:
    """Builds image model inputs for serving.

    Note: the `@classmethod` decorator was missing even though the first
    parameter is `cls`; without it, calling `Parser.inference_fn(image, ...)`
    on the class would bind the image tensor to `cls`.

    Args:
      image: The input image tensor.
      input_image_size: `list` of [height, width] to resize the image to.
      num_channels: `int`, number of image channels. Defaults to 3.

    Returns:
      The preprocessed, normalized image tensor of shape
      input_image_size + [num_channels].
    """
    image = tf.cast(image, dtype=tf.float32)
    image = preprocess_ops.center_crop_image(image)
    image = tf.image.resize(
        image, input_image_size, method=tf.image.ResizeMethod.BILINEAR)
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
    image.set_shape(input_image_size + [num_channels])
    return image