Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Library to facilitate TFLite model conversion.""" | |
| import functools | |
| from typing import Iterator, List, Optional | |
| from absl import logging | |
| import tensorflow as tf, tf_keras | |
| from official.core import base_task | |
| from official.core import config_definitions as cfg | |
| from official.vision import configs | |
| from official.vision import tasks | |
def create_representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None) -> tf.data.Dataset:
  """Builds the input dataset used as the TFLite representative dataset.

  The training input config in `params` is modified in place: its global batch
  size is set to 1 and its dtype to 'float32' before the inputs are built.

  Args:
    params: An ExperimentConfig. `params.task.train_data` describes the input
      pipeline and is mutated by this call.
    task: An optional task instance. When None, a task is instantiated from
      the type of `params.task`.

  Returns:
    The dataset returned by `task.build_inputs` for the training data config.

  Raises:
    ValueError: If `task` is None and the type of `params.task` is not one of
      the supported vision task configs.
  """
  task_config = params.task
  if task is None:
    task = _build_task_from_config(task_config)

  # TFLite calibration feeds one float32 example at a time.
  input_config = task_config.train_data
  input_config.global_batch_size = 1
  input_config.dtype = 'float32'
  logging.info('Task config: %s', task_config.as_dict())
  return task.build_inputs(params=input_config)


def _build_task_from_config(task_config):
  """Instantiates the vision task that corresponds to `task_config`'s type.

  Args:
    task_config: The task section of an ExperimentConfig.

  Returns:
    A task instance constructed from `task_config`.

  Raises:
    ValueError: If `task_config` is not an image classification, RetinaNet,
      Mask R-CNN or semantic segmentation task config.
  """
  if isinstance(task_config,
                configs.image_classification.ImageClassificationTask):
    return tasks.image_classification.ImageClassificationTask(task_config)
  if isinstance(task_config, configs.retinanet.RetinaNetTask):
    return tasks.retinanet.RetinaNetTask(task_config)
  if isinstance(task_config, configs.maskrcnn.MaskRCNNTask):
    return tasks.maskrcnn.MaskRCNNTask(task_config)
  if isinstance(task_config,
                configs.semantic_segmentation.SemanticSegmentationTask):
    return tasks.semantic_segmentation.SemanticSegmentationTask(task_config)
  raise ValueError('Task {} not supported.'.format(type(task_config)))
def representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None,
    calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
  """Generates calibration inputs for post-training quantization.

  Args:
    params: An ExperimentConfig.
    task: An optional task instance. If it is None, task will be built according
      to the task type in params.
    calibration_steps: The number of dataset elements to draw. Elements that
      are skipped by the channel filter still count towards this limit.

  Yields:
    A single-element list holding one image tensor whose last dimension is 3.
  """
  calibration_data = create_representative_dataset(
      params=params, task=task).take(calibration_steps)
  for image, _ in calibration_data:
    # Only images with 3 channels are passed on for calibration.
    if image.shape[-1] == 3:
      yield [image]
def convert_tflite_model(
    saved_model_dir: Optional[str] = None,
    concrete_function: Optional[tf.types.experimental.ConcreteFunction] = None,
    model: Optional[tf.Module] = None,
    quant_type: Optional[str] = None,
    params: Optional[cfg.ExperimentConfig] = None,
    task: Optional[base_task.Task] = None,
    calibration_steps: Optional[int] = 2000,
    denylisted_ops: Optional[List[str]] = None,
) -> bytes:
  """Converts and returns a TFLite model.

  The model source is picked in priority order: `saved_model_dir`, then
  `concrete_function`, then `model`.

  Args:
    saved_model_dir: The directory to the SavedModel.
    concrete_function: An optional concrete function to be exported.
    model: An optional tf_keras.Model instance. If both `saved_model_dir` and
      `concrete_function` are not available, convert this model to TFLite.
    quant_type: The quantization method. One of:
      * None or empty: no quantization.
      * `default`: dynamic range quantization.
      * `fp16`: float16 quantization.
      * `int8`: integer quantization with float fallback.
      * `int8_full`: integer only, with uint8 input/output.
      * `int8_full_int8_io`: integer only, with int8 input/output.
      * `uint8`: uint8 inference type with fixed default ranges.
      * `qat`: default optimizations with uint8 input/output.
      * `qat_fp32_io`: default optimizations, input/output types untouched.
      Any value starting with `int8` is handled as integer quantization.
    params: An ExperimentConfig to load and preprocess input images to do
      calibration. Required for the `int8*` quantization types.
    task: An optional task instance. If it is None, task will be built according
      to the task type in params.
    calibration_steps: The steps to do calibration.
    denylisted_ops: A list of strings containing ops that are excluded from
      integer quantization. Only used for the `int8*` quantization types.

  Returns:
    The converted TFLite model, quantized according to `quant_type`.

  Raises:
    ValueError: If none of `saved_model_dir`, `concrete_function` or `model`
      is provided, if integer quantization is requested without `params`, or
      if `quant_type` is not supported.
  """
  if saved_model_dir:
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  elif concrete_function is not None:
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_function]
    )
  elif model is not None:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
  else:
    raise ValueError(
        '`saved_model_dir`, `model` or `concrete_function` must be specified.'
    )

  if not quant_type:
    return converter.convert()

  if quant_type.startswith('int8'):
    # The calibration generator reads `params.task`; without `params` it would
    # only fail later, inside the converter, with an opaque AttributeError.
    if params is None:
      raise ValueError(
          '`params` must be provided to build the representative dataset for '
          f'integer quantization (quant_type={quant_type!r}).')

    # One factory shared by calibration and the debugger so that both see the
    # same params, task and number of steps.
    calibration_dataset = functools.partial(
        representative_dataset,
        params=params,
        task=task,
        calibration_steps=calibration_steps)

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_dataset
    if quant_type.startswith('int8_full'):
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
      ]
    if quant_type == 'int8_full':
      converter.inference_input_type = tf.uint8
      converter.inference_output_type = tf.uint8
    if quant_type == 'int8_full_int8_io':
      converter.inference_input_type = tf.int8
      converter.inference_output_type = tf.int8

    if denylisted_ops:
      # The debugger drives the conversion itself so the denylisted ops are
      # left out of quantization; its output replaces `converter.convert()`.
      debug_options = tf.lite.experimental.QuantizationDebugOptions(
          denylisted_ops=denylisted_ops)
      debugger = tf.lite.experimental.QuantizationDebugger(
          converter=converter,
          debug_dataset=calibration_dataset,
          debug_options=debug_options)
      debugger.run()
      return debugger.get_nondebug_quantized_model()
  elif quant_type == 'uint8':
    # NOTE(review): these look like TF1-style converter attributes — confirm
    # they are honored by the converter in use.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.default_ranges_stats = (-10, 10)
    converter.inference_type = tf.uint8
    converter.quantized_input_stats = {'input_placeholder': (0., 1.)}
  elif quant_type == 'fp16':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
  elif quant_type in ('default', 'qat_fp32_io'):
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
  elif quant_type == 'qat':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.inference_input_type = tf.uint8  # or tf.int8
    converter.inference_output_type = tf.uint8  # or tf.int8
  else:
    raise ValueError(f'quantization type {quant_type} is not supported.')

  return converter.convert()