import os
import pathlib

import numpy as np
import tensorflow as tf
import tqdm
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig

from common.optimization import model_formatting_ptq_per_tensor


class TFLitePTQQuantizer:
    """
    Handles TensorFlow Lite post-training quantization (PTQ).

    Args:
        cfg (DictConfig): Configuration object for quantization.
        model (tf.keras.Model): The TensorFlow model to quantize.
        dataloaders (dict): Dictionary of datasets; the 'quantization' entry
            provides the representative dataset used for calibration.
    """
    def __init__(self, cfg: DictConfig = None, model: tf.keras.Model = None,
                 dataloaders: dict = None):
        self.cfg = cfg
        self.model = model
        self.quantization_ds = dataloaders['quantization']
        self.output_dir = HydraConfig.get().runtime.output_dir
        self.export_dir = cfg.quantization.export_dir
        self.quantized_model = None

    def _representative_data_gen(self, input_shape):
        """
        Generates representative data used to calibrate activation ranges
        during quantization.

        Args:
            input_shape (tuple): Shape of the model input, without the
                batch dimension.

        Yields:
            list: A one-element list containing a single float32 sample
                with a batch dimension of 1.
        """
        if not self.quantization_ds:
            # No calibration dataset was provided: fall back to a handful of
            # random samples. The resulting calibration statistics are not
            # meaningful; use real data whenever possible.
            for _ in tqdm.tqdm(range(5)):
                data = np.random.rand(1, input_shape[0], input_shape[1], input_shape[2])
                yield [data.astype(np.float32)]
        else:
            # Feed the calibration images one at a time, each with a batch
            # dimension of 1, as the converter expects. Labels are unused.
            for images, _ in tqdm.tqdm(self.quantization_ds, total=len(self.quantization_ds)):
                for image in images:
                    image = tf.cast(image, dtype=tf.float32)
                    image = tf.expand_dims(image, 0)
                    yield [image]

    def _run_quantization(self):
        """
        Runs the quantization process and saves the quantized model.
        """
        # Drop the batch dimension; the representative data generator adds
        # it back per sample.
        input_shape = self.model.input_shape[1:]

        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)

        # Set the input/output tensor types requested in the configuration
        # (float32 is kept when neither 'int8' nor 'uint8' is specified).
        q_input = self.cfg.quantization.quantization_input_type
        q_output = self.cfg.quantization.quantization_output_type
        if q_input == 'int8':
            converter.inference_input_type = tf.int8
        elif q_input == 'uint8':
            converter.inference_input_type = tf.uint8
        if q_output == 'int8':
            converter.inference_output_type = tf.int8
        elif q_output == 'uint8':
            converter.inference_output_type = tf.uint8

        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = lambda: self._representative_data_gen(input_shape)

        if self.cfg.quantization.granularity == 'per_tensor':
            # Private, experimental converter flag: forces per-tensor weight
            # quantization instead of the per-channel default.
            converter._experimental_disable_per_channel = True

        tflite_model_quantized = converter.convert()

        # Write the quantized model under <output_dir>/<export_dir>/.
        tflite_models_dir = pathlib.Path(os.path.join(self.output_dir, self.export_dir))
        tflite_models_dir.mkdir(exist_ok=True, parents=True)
        tflite_model_path = tflite_models_dir / "quantized_model.tflite"
        tflite_model_path.write_bytes(tflite_model_quantized)

        # Reload the saved model into an interpreter and remember where it
        # was written.
        interpreter = tf.lite.Interpreter(model_path=str(tflite_model_path))
        interpreter.allocate_tensors()
        interpreter.model_path = str(tflite_model_path)
        self.quantized_model = interpreter

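    # Note on stricter integer-only conversion (a sketch, not used by this
    # class): _run_quantization relies on tf.lite.Optimize.DEFAULT plus a
    # representative dataset, which still permits float fallback for ops
    # without integer kernels. Enforcing integer-only execution would add,
    # before convert():
    #
    #     converter.target_spec.supported_ops = [
    #         tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
    #     ]
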
    def _prepare_quantization(self):
        """
        Prepares the model for quantization, applying optimizations when
        requested in the configuration.
        """
        if not isinstance(self.model, tf.keras.Model):
            raise ValueError(f"Unsupported model format: {type(self.model)}. "
                             "Expected a tf.keras.Model.")

        # Optional restructuring aimed at improving per-tensor quantization;
        # the optimized float model is saved alongside the quantized one.
        if self.cfg.quantization.granularity == 'per_tensor' and self.cfg.quantization.optimize:
            print("[INFO] : Optimizing the model for improved per_tensor quantization...")
            self.model = model_formatting_ptq_per_tensor(model_origin=self.model)
            models_dir = pathlib.Path(os.path.join(self.output_dir, self.export_dir))
            models_dir.mkdir(exist_ok=True, parents=True)
            model_path = models_dir / "optimized_model.keras"
            self.model.save(model_path)

    def quantize(self):
        """
        Executes the quantization process.

        Returns:
            tf.lite.Interpreter: The quantized TFLite model as an
            Interpreter object.
        """
        print("[INFO] : Quantizing the model ... This might take a few minutes ...")
        self._prepare_quantization()
        self._run_quantization()
        print('[INFO] : Quantization complete.')
        return self.quantized_model
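

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the class above). It
# assumes a Hydra run context, since __init__ reads
# HydraConfig.get().runtime.output_dir. The config name, `cfg.model.path`,
# and `cfg.dataset.quantization_path` are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import hydra

    @hydra.main(config_path=".", config_name="config", version_base=None)
    def main(cfg: DictConfig):
        # Hypothetical: load the float model and a saved calibration dataset.
        model = tf.keras.models.load_model(cfg.model.path)
        quantization_ds = tf.data.Dataset.load(cfg.dataset.quantization_path)

        quantizer = TFLitePTQQuantizer(cfg=cfg, model=model,
                                       dataloaders={'quantization': quantization_ds})
        interpreter = quantizer.quantize()

        # Single inference with the quantized interpreter. For int8/uint8
        # input types, real inputs must first be quantized using the scale
        # and zero point in input_details['quantization'].
        input_details = interpreter.get_input_details()[0]
        output_details = interpreter.get_output_details()[0]
        sample = np.zeros(input_details['shape'], dtype=input_details['dtype'])
        interpreter.set_tensor(input_details['index'], sample)
        interpreter.invoke()
        print(interpreter.get_tensor(output_details['index']).shape)

    main()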