|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Image classification input and model functions for serving/inference."""
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.vision.modeling import factory
|
| from official.vision.ops import preprocess_ops
|
| from official.vision.serving import export_base
|
|
|
|
|
| class ClassificationModule(export_base.ExportModule):
|
| """classification Module."""
|
|
|
| def _build_model(self):
|
| input_specs = tf_keras.layers.InputSpec(
|
| shape=[self._batch_size] + self._input_image_size + [3])
|
|
|
| return factory.build_classification_model(
|
| input_specs=input_specs,
|
| model_config=self.params.task.model,
|
| l2_regularizer=None)
|
|
|
| def _crop_and_resize(self, image):
|
| if self.params.task.train_data.aug_crop:
|
| image = preprocess_ops.center_crop_image(image)
|
|
|
| image = tf.image.resize(
|
| image, self._input_image_size, method=tf.image.ResizeMethod.BILINEAR)
|
|
|
| image = tf.reshape(
|
| image, [self._input_image_size[0], self._input_image_size[1], 3])
|
|
|
| return image
|
|
|
| def _build_inputs(self, image):
|
| """Builds classification model inputs for serving."""
|
|
|
| if isinstance(image, tf.RaggedTensor):
|
| image = image.to_tensor()
|
| image = tf.cast(image, dtype=tf.float32)
|
|
|
|
|
| if not (
|
| self._input_type in ['tf_example', 'image_bytes']
|
| and len(self._input_image_size) == 2):
|
| image = self._crop_and_resize(image)
|
|
|
|
|
| image = preprocess_ops.normalize_image(
|
| image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
|
| return image
|
|
|
| def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
|
| """Decodes an image bytes to an image tensor.
|
|
|
| Use `tf.image.decode_image` to decode an image if input is expected to be 2D
|
| image; otherwise use `tf.io.decode_raw` to convert the raw bytes to tensor
|
| and reshape it to desire shape.
|
|
|
| Args:
|
| encoded_image_bytes: An encoded image string to be decoded.
|
|
|
| Returns:
|
| A decoded image tensor.
|
| """
|
| if len(self._input_image_size) == 2:
|
|
|
| image_tensor = tf.image.decode_image(
|
| encoded_image_bytes, channels=self._num_channels
|
| )
|
| image_tensor.set_shape((None, None, self._num_channels))
|
|
|
|
|
| image_tensor = tf.cast(image_tensor, dtype=tf.float32)
|
| image_tensor = self._crop_and_resize(image_tensor)
|
| image_tensor = tf.cast(image_tensor, tf.uint8)
|
| return image_tensor
|
| else:
|
|
|
| image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
|
| image_tensor = tf.reshape(
|
| image_tensor, self._input_image_size + [self._num_channels]
|
| )
|
| return image_tensor
|
|
|
| def serve(self, images):
|
| """Cast image to float and run inference.
|
|
|
| Args:
|
| images: uint8 Tensor of shape [batch_size, None, None, 3]
|
| Returns:
|
| Tensor holding classification output logits.
|
| """
|
|
|
|
|
| if self._input_type != 'tflite':
|
| with tf.device('cpu:0'):
|
| images = tf.nest.map_structure(
|
| tf.identity,
|
| tf.map_fn(
|
| self._build_inputs,
|
| elems=images,
|
| fn_output_signature=tf.TensorSpec(
|
| shape=self._input_image_size + [3], dtype=tf.float32),
|
| parallel_iterations=32))
|
|
|
| logits = self.inference_step(images)
|
| if self.params.task.train_data.is_multilabel:
|
| probs = tf.math.sigmoid(logits)
|
| else:
|
| probs = tf.nn.softmax(logits)
|
|
|
| return {'logits': logits, 'probs': probs}
|
|
|