"""
scripts/export_tflite.py
Quantize a SavedModel to an INT8 TFLite model for deployment.
Usage:
python scripts/export_tflite.py \
--saved_model models/waste_classifier_v1 \
--data_dir data/processed \
--output models/model.tflite
"""
import argparse
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# MobileNetV2 preprocessing scales pixels to [-1, 1]; calibration samples must
# go through the same transform the model saw during training.
PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input


def representative_dataset(data_dir: str, n_samples: int = 200):
    """Yield calibration samples for INT8 quantization."""
    datagen = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)
    generator = datagen.flow_from_directory(
        os.path.join(data_dir, "train"),
        target_size=(224, 224),
        batch_size=1,
        class_mode=None,
        shuffle=True,
        seed=42,
    )
    # The directory iterator cycles forever, so stop after n_samples batches.
    for index, image in enumerate(generator):
        if index >= n_samples:
            break
        yield [image.astype(np.float32)]
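

# Note (added): the converter runs these samples through the float model to
# record activation ranges, then picks an affine mapping per tensor:
#     real_value = scale * (quantized_value - zero_point)
# The TFLite docs suggest roughly 100-500 calibration samples; 200 is used here.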


def export(saved_model_path: str, data_dir: str, output_path: str):
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = lambda: representative_dataset(data_dir)
    # Full-integer quantization: every op must have an INT8 kernel, and the
    # model's input/output tensors become uint8, so no float conversion is
    # needed on device.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as file:
        file.write(tflite_model)

    # Compare on-disk sizes: the SavedModel directory vs. the single .tflite file.
    original_mb = sum(
        os.path.getsize(os.path.join(root, filename))
        for root, _, files in os.walk(saved_model_path)
        for filename in files
    ) / 1024 / 1024
    tflite_mb = os.path.getsize(output_path) / 1024 / 1024
    print(f"Original SavedModel : {original_mb:.1f} MB")
    print(f"TFLite INT8 model   : {tflite_mb:.1f} MB")
    print(f"Reduction           : {(1 - tflite_mb / original_mb) * 100:.0f}%")
    print(f"Exported -> {output_path}")
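

# Note (added): TFLITE_BUILTINS_INT8 makes conversion fail if any op lacks an
# INT8 kernel. Should that happen, a common fallback is to allow float kernels
# for the unsupported ops, producing a mixed int8/float model:
#
#     converter.target_spec.supported_ops = [
#         tf.lite.OpsSet.TFLITE_BUILTINS_INT8,  # prefer int8 kernels
#         tf.lite.OpsSet.TFLITE_BUILTINS,       # fall back to float32 kernels
#     ]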


def benchmark(model_path: str, n_runs: int = 50):
    """Run a quick CPU latency benchmark."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]

    # randint's upper bound is exclusive; use 256 to cover the full uint8 range.
    dummy = np.random.randint(0, 256, (1, 224, 224, 3), dtype=np.uint8)

    # Warm-up run so one-time allocation costs don't skew the timings.
    interpreter.set_tensor(input_details["index"], dummy)
    interpreter.invoke()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        interpreter.set_tensor(input_details["index"], dummy)
        interpreter.invoke()
        times.append((time.perf_counter() - start) * 1000)

    times.sort()
    print(f"\nLatency over {n_runs} runs (CPU):")
    print(f"  Median : {np.median(times):.1f} ms")
    print(f"  p95    : {times[min(int(0.95 * n_runs), n_runs - 1)]:.1f} ms")
    print(f"  Min    : {times[0]:.1f} ms")
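

# Hedged sketch (not part of the original script): the uint8 outputs of a
# fully quantized model are raw quantized values, not class scores. To recover
# float scores, dequantize with the output tensor's (scale, zero_point).
# `predict_once` is a hypothetical helper illustrating this.
def predict_once(model_path: str, image: np.ndarray) -> np.ndarray:
    """Run one (224, 224, 3) uint8 image through the model; return float scores."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    interpreter.set_tensor(input_details["index"], image[np.newaxis])
    interpreter.invoke()
    raw = interpreter.get_tensor(output_details["index"])[0]
    scale, zero_point = output_details["quantization"]
    # Dequantize: real_value = scale * (quantized_value - zero_point).
    return scale * (raw.astype(np.float32) - zero_point)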


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--saved_model", default="models/waste_classifier_v1")
    parser.add_argument("--data_dir", default="data/processed")
    parser.add_argument("--output", default="models/model.tflite")
    parser.add_argument("--benchmark", action="store_true")
    args = parser.parse_args()

    export(args.saved_model, args.data_dir, args.output)
    if args.benchmark:
        benchmark(args.output)