File size: 3,486 Bytes
206d8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
scripts/export_tflite.py
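Quantize a SavedModel to a full-integer (INT8) TFLite model for deployment,
calibrating on images from <data_dir>/train and optionally benchmarking the
exported model's CPU latency.

Usage:
    python scripts/export_tflite.py \
        --saved_model models/waste_classifier_v1 \
        --data_dir data/processed \
        --output models/model.tflite \
        [--benchmark]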
"""

import argparse
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Keras MobileNetV2 input preprocessing, applied to every calibration image
# in representative_dataset().
# NOTE(review): presumably this must match the preprocessing used when the
# SavedModel was trained — confirm against the training script.
PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input


def representative_dataset(
    data_dir: str,
    n_samples: int = 200,
    target_size: tuple = (224, 224),
):
    """Yield calibration samples for INT8 quantization.

    Images are read from ``<data_dir>/train`` via ``flow_from_directory``,
    passed through ``PREPROCESS_INPUT`` and yielded one at a time in the
    ``[input_array]`` form that ``TFLiteConverter.representative_dataset``
    consumes.

    Args:
        data_dir: Root of the processed dataset; must contain ``train/``.
        n_samples: Number of calibration samples to yield. If this exceeds
            the number of images, the Keras iterator cycles over them again.
        target_size: ``(height, width)`` each image is resized to; should
            match the model's input resolution.

    Yields:
        A one-element list holding a float32 array with batch dimension 1
        (RGB, the ``flow_from_directory`` default colour mode).

    Raises:
        ValueError: If no images are found under ``<data_dir>/train``.
    """
    train_dir = os.path.join(data_dir, "train")
    datagen = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)
    generator = datagen.flow_from_directory(
        train_dir,
        target_size=target_size,
        batch_size=1,
        class_mode=None,  # images only; labels are not needed for calibration
        shuffle=True,
        seed=42,  # fixed seed so the calibration sample order is reproducible
    )
    # With no images the (endless) Keras iterator yields zero-sized batches,
    # which only shows up later as an opaque converter failure; fail fast.
    if generator.samples == 0:
        raise ValueError(
            f"No calibration images found under {train_dir!r}; INT8 "
            "quantization needs a representative dataset."
        )
    # The Keras iterator never terminates on its own: draw exactly n_samples.
    for _ in range(n_samples):
        batch = next(generator)
        yield [batch.astype(np.float32)]


def _dir_size_mb(path: str) -> float:
    """Return the total size in MiB of all files under directory *path*."""
    total_bytes = sum(
        os.path.getsize(os.path.join(root, filename))
        for root, _, files in os.walk(path)
        for filename in files
    )
    return total_bytes / 1024 / 1024


def export(
    saved_model_path: str,
    data_dir: str,
    output_path: str,
    n_samples: int = 200,
):
    """Convert a SavedModel to a full-integer TFLite model and report sizes.

    Weights and activations are quantized using calibration samples drawn
    by ``representative_dataset``; the exported model's input and output
    tensors are uint8.

    Args:
        saved_model_path: Directory of the SavedModel to convert.
        data_dir: Dataset root forwarded to ``representative_dataset``.
        output_path: Destination ``.tflite`` file; missing parent
            directories are created.
        n_samples: Number of calibration samples (default 200, the value
            previously hard-coded via ``representative_dataset``).
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # The converter calls this factory to obtain a fresh calibration generator.
    converter.representative_dataset = lambda: representative_dataset(
        data_dir, n_samples
    )
    # Integer-only builtin ops: conversion errors out instead of keeping
    # float kernels for ops that cannot be quantized.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8

    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as file:
        file.write(tflite_model)

    original_mb = _dir_size_mb(saved_model_path)
    tflite_mb = os.path.getsize(output_path) / 1024 / 1024

    print(f"Original SavedModel : {original_mb:.1f} MB")
    print(f"TFLite INT8 model   : {tflite_mb:.1f} MB")
    print(f"Reduction           : {(1 - tflite_mb / original_mb) * 100:.0f}%")
    print(f"Exported -> {output_path}")


def _random_input(shape, dtype) -> np.ndarray:
    """Return a random array of the given shape/dtype for latency timing."""
    dtype = np.dtype(dtype)
    size = tuple(int(dim) for dim in shape)
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        # randint's upper bound is exclusive; +1 makes info.max reachable.
        return np.random.randint(info.min, int(info.max) + 1, size=size, dtype=dtype)
    # Non-integer inputs: values only need to be valid numbers for timing.
    return np.random.uniform(-1.0, 1.0, size=size).astype(dtype)


def benchmark(model_path: str, n_runs: int = 50):
    """Run a quick CPU latency benchmark of a TFLite model and print stats.

    A random tensor matching the model's first input (shape and dtype are
    read from the interpreter rather than assumed) is fed once as a warm-up,
    then ``n_runs`` timed runs are made. Each timing covers ``set_tensor``
    plus ``invoke``.

    Args:
        model_path: Path to the ``.tflite`` file.
        n_runs: Number of timed runs; must be at least 1.

    Raises:
        ValueError: If ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    dummy = _random_input(input_details["shape"], input_details["dtype"])

    # Warm-up: the first invoke() may include one-time setup cost, so it is
    # excluded from the timed runs.
    interpreter.set_tensor(input_details["index"], dummy)
    interpreter.invoke()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        interpreter.set_tensor(input_details["index"], dummy)
        interpreter.invoke()
        times.append((time.perf_counter() - start) * 1000)

    print(f"\nLatency over {n_runs} runs (CPU):")
    print(f"  Median : {np.median(times):.1f} ms")
    print(f"  p95    : {np.percentile(times, 95):.1f} ms")
    print(f"  Min    : {min(times):.1f} ms")


if __name__ == "__main__":
    # Command-line entry point: export the model, then optionally benchmark it.
    cli = argparse.ArgumentParser()
    path_options = (
        ("--saved_model", "models/waste_classifier_v1"),
        ("--data_dir", "data/processed"),
        ("--output", "models/model.tflite"),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, default=fallback)
    cli.add_argument("--benchmark", action="store_true")
    opts = cli.parse_args()

    export(opts.saved_model, opts.data_dir, opts.output)
    if opts.benchmark:
        benchmark(opts.output)