Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Factory for vision export modules.""" | |
| from typing import List, Optional | |
| import tensorflow as tf, tf_keras | |
| from official.core import config_definitions as cfg | |
| from official.vision import configs | |
| from official.vision.dataloaders import classification_input | |
| from official.vision.modeling import factory | |
| from official.vision.serving import export_base_v2 as export_base | |
| from official.vision.serving import export_utils | |
def create_classification_export_module(
    params: cfg.ExperimentConfig,
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    num_channels: int = 3) -> export_base.ExportModule:
  """Creates a classification export module.

  Args:
    params: Experiment config; `params.task.model` is used to build the
      classification model.
    input_type: Serving input type, forwarded to
      `export_utils.get_image_input_signatures` and
      `export_utils.parse_image`. When it is `'tflite'`, the parsed image is
      returned without the per-image inference preprocessing.
    batch_size: Batch dimension used for the input signature and the model
      input spec. May be None (the caller `get_export_module` passes an
      Optional value) — presumably meaning a dynamic batch size; TODO
      confirm against `export_utils`.
    input_image_size: Spatial size of the input image (presumably
      `[height, width]`; TODO confirm ordering). Any sequence is accepted; it
      is copied into a list before use.
    num_channels: Number of image channels. Defaults to 3.

  Returns:
    An `export_base.ExportModule` wrapping the model, the input signature, a
    preprocessor, and a postprocessor that returns a dict with `'logits'` and
    `'probs'` (softmax of the logits).
  """
  # Copy into a list: the shape construction below uses list concatenation
  # (a tuple would raise TypeError), and the closures below should not see
  # later mutations of the caller's list.
  image_size = list(input_image_size)

  input_signature = export_utils.get_image_input_signatures(
      input_type, batch_size, image_size, num_channels)
  input_specs = tf_keras.layers.InputSpec(
      shape=[batch_size] + image_size + [num_channels])
  model = factory.build_classification_model(
      input_specs=input_specs,
      model_config=params.task.model,
      l2_regularizer=None)

  def preprocess_fn(inputs):
    image_tensor = export_utils.parse_image(inputs, input_type, image_size,
                                            num_channels)
    # If input_type is `tflite`, do not apply image preprocessing.
    if input_type == 'tflite':
      return image_tensor

    def preprocess_image_fn(image):
      return classification_input.Parser.inference_fn(
          image, image_size, num_channels)

    # tf.map_fn applies the per-image preprocessing to each element along
    # axis 0 of the parsed tensor (presumably the batch axis).
    return tf.map_fn(
        preprocess_image_fn,
        elems=image_tensor,
        fn_output_signature=tf.TensorSpec(
            shape=image_size + [num_channels], dtype=tf.float32))

  def postprocess_fn(logits):
    # Expose both the raw logits and their softmax probabilities.
    probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}

  return export_base.ExportModule(
      params,
      model=model,
      input_signature=input_signature,
      preprocessor=preprocess_fn,
      postprocessor=postprocess_fn)
def get_export_module(params: cfg.ExperimentConfig,
                      input_type: str,
                      batch_size: Optional[int],
                      input_image_size: List[int],
                      num_channels: int = 3) -> export_base.ExportModule:
  """Returns the export module matching the task configured in `params`.

  Only image classification tasks are currently supported; all arguments
  are forwarded unchanged to `create_classification_export_module`.

  Args:
    params: Experiment config whose `task` selects the export module.
    input_type: Serving input type.
    batch_size: Batch dimension for the exported signature; may be None.
    input_image_size: Spatial size of the input image.
    num_channels: Number of image channels. Defaults to 3.

  Returns:
    The `export_base.ExportModule` built for the configured task.

  Raises:
    ValueError: If `params.task` is not an `ImageClassificationTask`.
  """
  task_config = params.task
  classification_task_cls = (
      configs.image_classification.ImageClassificationTask)

  # Guard clause: reject any task type without an export implementation.
  if not isinstance(task_config, classification_task_cls):
    raise ValueError('Export module not implemented for {} task.'.format(
        type(task_config)))

  return create_classification_export_module(
      params, input_type, batch_size, input_image_size, num_channels)