#  /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

# Import necessary libraries
import pathlib
import numpy as np
import tensorflow as tf
import tqdm
import os
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig

# Import utility functions and modules
from common.optimization import model_formatting_ptq_per_tensor


# Define a class for TensorFlow Lite Post-Training Quantization (PTQ)
class TFLitePTQQuantizer:
    """

    A class to handle TensorFlow Lite Post-Training Quantization (PTQ).



    Args:

        cfg (DictConfig): Configuration object for quantization.

        model (tf.keras.Model): The TensorFlow model to quantize.

        dataloaders (dict): Dictionary containing datasets for quantization and testing.

    """
    def __init__(self, cfg: DictConfig = None, model: object = None,
                 dataloaders: dict = None):
        self.cfg = cfg
        self.model = model
        self.quantization_ds = dataloaders['quantization']
        self.output_dir = HydraConfig.get().runtime.output_dir
        self.export_dir = cfg.quantization.export_dir
        self.quantized_model = None

    def _representative_data_gen(self, input_shape):
        """

        Generates representative data for quantization.



        Args:

            input_shape (tuple): Shape of the model input.



        Yields:

            List[np.ndarray]: Representative data samples.

        """
        if not self.quantization_ds:
            # If no dataset is provided, generate random data
            for _ in tqdm.tqdm(range(5)):
                data = np.random.rand(1, input_shape[0], input_shape[1], input_shape[2])
                yield [data.astype(np.float32)]
        else:
            # Use the provided dataset for representative data; labels are not needed
            for images, _ in tqdm.tqdm(self.quantization_ds, total=len(self.quantization_ds)):
                for image in images:
                    image = tf.cast(image, dtype=tf.float32)
                    image = tf.expand_dims(image, 0)  # Add batch dimension
                    yield [image]

    def _run_quantization(self):
        """

        Runs the quantization process and saves the quantized model.

        """
        # Get the input shape of the model, excluding the batch dimension
        input_shape = self.model.input_shape[1:]
        # Create a TFLite converter from the Keras model
        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)

        # Set input and output quantization types
        q_input = self.cfg.quantization.quantization_input_type
        q_output = self.cfg.quantization.quantization_output_type
        if q_input == 'int8':
            converter.inference_input_type = tf.int8
        elif q_input == 'uint8':
            converter.inference_input_type = tf.uint8
        if q_output == 'int8':
            converter.inference_output_type = tf.int8
        elif q_output == 'uint8':
            converter.inference_output_type = tf.uint8

        # Enable default optimizations and set the representative dataset
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = lambda: self._representative_data_gen(input_shape)

        # Set quantization granularity: per-channel is the converter default, so per-tensor
        # has to be requested explicitly (through a private converter attribute)
        if self.cfg.quantization.granularity == 'per_tensor':
            converter._experimental_disable_per_channel = True

        # Convert the model to TFLite format
        tflite_model_quantized = converter.convert()

        # Save the quantized model to the specified directory
        tflite_models_dir = pathlib.Path(self.output_dir) / self.export_dir
        tflite_models_dir.mkdir(exist_ok=True, parents=True)
        tflite_model_path = tflite_models_dir / "quantized_model.tflite"
        tflite_model_path.write_bytes(tflite_model_quantized)

        # Load the quantized model as a TFLite Interpreter
        interpreter = tf.lite.Interpreter(model_path=str(tflite_model_path))
        interpreter.allocate_tensors()
        setattr(interpreter, 'model_path', str(tflite_model_path))  # Attach the model path as an attribute
        self.quantized_model = interpreter

    def _prepare_quantization(self):
        """

        Prepares the model for quantization by applying optimizations if necessary.

        """
        # Ensure the model is a TensorFlow Keras model
        if not isinstance(self.model, tf.keras.Model):
            raise ValueError(f"Unsupported model format: {type(self.model)}. "
                             "Expected a tf.keras.Model instance.")

        # Apply optimizations for per-tensor quantization if specified
        if self.cfg.quantization.granularity == 'per_tensor' and self.cfg.quantization.optimize:
            print("[INFO] : Optimizing the model for improved per_tensor quantization...")
            self.model = model_formatting_ptq_per_tensor(model_origin=self.model)
            models_dir = pathlib.Path(self.output_dir) / self.export_dir
            models_dir.mkdir(exist_ok=True, parents=True)
            model_path = models_dir / "optimized_model.keras"
            self.model.save(model_path)

    def quantize(self):
        """

        Executes the quantization process.



        Returns:

            tf.lite.Interpreter: The quantized TFLite model as an Interpreter object.

        """
        print("[INFO] : Quantizing the model ... This might take few minutes ...")
        self._prepare_quantization()  # Prepare the model for quantization
        self._run_quantization()      # Run the quantization process
        print('[INFO] : Quantization complete.')
        return self.quantized_model   # Return the quantized model
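

# --------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). TFLitePTQQuantizer reads the
# run output directory from HydraConfig in __init__, so it is meant to be driven from inside a
# Hydra application. The config keys it consumes are the ones read above (quantization.export_dir,
# quantization_input_type, quantization_output_type, granularity, optimize); the model path key
# and the dataset value below are assumptions made for the example.
#
#   import hydra
#
#   @hydra.main(config_path="config", config_name="quantization_config", version_base=None)
#   def main(cfg: DictConfig) -> None:
#       model = tf.keras.models.load_model(cfg.model_path)  # hypothetical config key
#       # Pass None to fall back on random calibration data, or a tf.data.Dataset of
#       # (images, labels) batches to calibrate on real samples
#       quantizer = TFLitePTQQuantizer(cfg=cfg, model=model,
#                                      dataloaders={'quantization': None})
#       interpreter = quantizer.quantize()
#       print("Quantized model saved to:", interpreter.model_path)
#
#   if __name__ == "__main__":
#       main()
# --------------------------------------------------------------------------------------------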