|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Factory for vision export modules."""
|
|
|
| from typing import List, Optional
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.core import config_definitions as cfg
|
| from official.vision import configs
|
| from official.vision.dataloaders import classification_input
|
| from official.vision.modeling import factory
|
| from official.vision.serving import export_base_v2 as export_base
|
| from official.vision.serving import export_utils
|
|
|
|
|
def create_classification_export_module(
        params: cfg.ExperimentConfig,
        input_type: str,
        batch_size: Optional[int],
        input_image_size: List[int],
        num_channels: int = 3) -> export_base.ExportModule:
    """Creates the export module for an image classification model.

    Builds the classification model from `params.task.model` and wraps it in
    an `export_base.ExportModule` together with the input signature and the
    pre-/post-processing callables defined below.

    Args:
        params: Experiment config; `params.task.model` is the model config
            passed to the model factory.
        input_type: Input type string forwarded to the `export_utils`
            helpers. The value 'tflite' disables per-image preprocessing.
        batch_size: Leading dimension of the model input spec. It is
            forwarded unchanged, so `None` is passed through as-is
            (presumably meaning a dynamic batch dimension -- confirm against
            `export_utils.get_image_input_signatures`).
        input_image_size: Image size dimensions placed between the batch and
            channel dimensions of the input spec. Any sequence is accepted;
            it is copied into a list before use.
        num_channels: Size of the trailing channel dimension.

    Returns:
        The `export_base.ExportModule` wrapping the built model.
    """
    # Normalise once so the list concatenations below also work when the
    # caller hands in a tuple (list + tuple raises TypeError).
    image_size = list(input_image_size)
    image_shape = image_size + [num_channels]

    input_signature = export_utils.get_image_input_signatures(
        input_type, batch_size, image_size, num_channels)
    input_specs = tf_keras.layers.InputSpec(shape=[batch_size] + image_shape)

    # No L2 regularizer is attached to the exported model.
    model = factory.build_classification_model(
        input_specs=input_specs,
        model_config=params.task.model,
        l2_regularizer=None)

    def preprocess_fn(inputs):
        """Parses raw inputs and applies per-image inference preprocessing."""
        image_tensor = export_utils.parse_image(
            inputs, input_type, image_size, num_channels)

        # For `tflite` inputs the parsed tensor is returned without any
        # further image preprocessing.
        if input_type == 'tflite':
            return image_tensor

        def preprocess_image_fn(inputs):
            """Runs the classification parser's inference preprocessing."""
            return classification_input.Parser.inference_fn(
                inputs, image_size, num_channels)

        # Preprocess every element of the batch independently; each result is
        # declared as a float32 tensor of shape image_size + [num_channels].
        return tf.map_fn(
            preprocess_image_fn,
            elems=image_tensor,
            fn_output_signature=tf.TensorSpec(
                shape=image_shape, dtype=tf.float32))

    def postprocess_fn(logits):
        """Returns the raw logits together with their softmax."""
        probs = tf.nn.softmax(logits)
        return {'logits': logits, 'probs': probs}

    return export_base.ExportModule(
        params,
        model=model,
        input_signature=input_signature,
        preprocessor=preprocess_fn,
        postprocessor=postprocess_fn)
|
|
|
|
|
def get_export_module(params: cfg.ExperimentConfig,
                      input_type: str,
                      batch_size: Optional[int],
                      input_image_size: List[int],
                      num_channels: int = 3) -> export_base.ExportModule:
    """Returns the export module matching the task type in `params`.

    Only image classification tasks are handled here; every other task type
    is rejected.

    Args:
        params: Experiment config whose `task` attribute selects the module.
        input_type: Forwarded to the export-module builder.
        batch_size: Forwarded to the export-module builder.
        input_image_size: Forwarded to the export-module builder.
        num_channels: Forwarded to the export-module builder.

    Returns:
        The export module built for the task.

    Raises:
        ValueError: If `params.task` is not an `ImageClassificationTask`.
    """
    task = params.task
    supported_task = configs.image_classification.ImageClassificationTask

    # Guard clause: reject unsupported task types up front.
    if not isinstance(task, supported_task):
        raise ValueError('Export module not implemented for {} task.'.format(
            type(params.task)))

    return create_classification_export_module(
        params, input_type, batch_size, input_image_size, num_channels)
|
|
|