r"""
scripts/export_tflite.py

Quantize a SavedModel to an INT8 TFLite model for deployment.

Usage:
    python scripts/export_tflite.py \
        --saved_model models/waste_classifier_v1 \
        --data_dir data/processed \
        --output models/model.tflite
"""

import argparse
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input


def representative_dataset(data_dir: str, n_samples: int = 200):
    """Yield calibration samples for INT8 quantization.

    Images are read from ``<data_dir>/train`` one at a time, resized to
    224x224 and passed through the MobileNetV2 preprocessing function.

    Args:
        data_dir: Directory that contains a ``train`` sub-directory of images.
        n_samples: Maximum number of single-image batches to yield.

    Yields:
        A one-element list holding a float32 image batch, which is the
        format the TFLite converter expects from a representative dataset.

    Raises:
        ValueError: If no images are found under ``<data_dir>/train``.
    """
    train_dir = os.path.join(data_dir, "train")
    datagen = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)
    generator = datagen.flow_from_directory(
        train_dir,
        target_size=(224, 224),
        batch_size=1,
        class_mode=None,
        shuffle=True,
        seed=42,
    )
    # Fail with an explicit message instead of an obscure error raised from
    # inside the iterator when there is nothing to calibrate on.
    if generator.samples == 0:
        raise ValueError(f"No calibration images found under {train_dir!r}")

    # zip() stops as soon as range() is exhausted, so no extra image is
    # loaded from disk after the last requested sample.
    for _, image in zip(range(n_samples), generator):
        yield [image.astype(np.float32)]


def _directory_size_mb(path: str) -> float:
    """Return the total size, in MiB, of all files below ``path``."""
    total_bytes = sum(
        os.path.getsize(os.path.join(root, filename))
        for root, _, files in os.walk(path)
        for filename in files
    )
    return total_bytes / 1024 / 1024


def export(
    saved_model_path: str,
    data_dir: str,
    output_path: str,
    n_samples: int = 200,
) -> None:
    """Convert a SavedModel to a fully-integer TFLite model and report sizes.

    The converter is restricted to INT8 built-in ops and the model's
    input and output tensors are set to uint8. Calibration data comes from
    ``representative_dataset(data_dir, n_samples)``.

    Args:
        saved_model_path: Directory of the SavedModel to convert.
        data_dir: Dataset root passed to ``representative_dataset``.
        output_path: Destination file for the ``.tflite`` flatbuffer; parent
            directories are created if missing.
        n_samples: Number of calibration samples used for quantization.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = lambda: representative_dataset(
        data_dir, n_samples
    )
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    tflite_model = converter.convert()

    # dirname() is empty for a bare file name; fall back to the current dir.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as file:
        file.write(tflite_model)

    original_mb = _directory_size_mb(saved_model_path)
    tflite_mb = os.path.getsize(output_path) / 1024 / 1024
    print(f"Original SavedModel : {original_mb:.1f} MB")
    print(f"TFLite INT8 model : {tflite_mb:.1f} MB")
    print(f"Reduction : {(1 - tflite_mb / original_mb) * 100:.0f}%")
    print(f"Exported -> {output_path}")


def _random_input(shape, dtype) -> np.ndarray:
    """Build a random tensor matching the given input shape and dtype.

    Integer dtypes are sampled over their full value range; any other dtype
    is sampled uniformly from [0, 1).
    """
    dims = tuple(int(dim) for dim in shape)
    np_dtype = np.dtype(dtype)
    if np.issubdtype(np_dtype, np.integer):
        limits = np.iinfo(np_dtype)
        # randint's upper bound is exclusive, hence the +1.
        return np.random.randint(
            limits.min, limits.max + 1, size=dims, dtype=np_dtype
        )
    return np.random.random_sample(dims).astype(np_dtype)


def benchmark(model_path: str, n_runs: int = 50) -> None:
    """Run a quick CPU latency benchmark.

    A random input matching the model's first input tensor is fed through the
    interpreter ``n_runs`` times (after one untimed warm-up invocation) and
    the median, p95 and minimum latency are printed in milliseconds.

    Args:
        model_path: Path to the ``.tflite`` model file.
        n_runs: Number of timed invocations; must be at least 1.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    input_index = input_details["index"]
    dummy = _random_input(input_details["shape"], input_details["dtype"])

    # Warm-up run so one-time setup cost is not counted in the timings.
    interpreter.set_tensor(input_index, dummy)
    interpreter.invoke()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        interpreter.set_tensor(input_index, dummy)
        interpreter.invoke()
        times.append((time.perf_counter() - start) * 1000)

    times.sort()
    print(f"\nLatency over {n_runs} runs (CPU):")
    print(f" Median : {np.median(times):.1f} ms")
    # int(0.95 * n_runs) is always < n_runs for n_runs >= 1, so it is in range.
    print(f" p95 : {times[int(0.95 * n_runs)]:.1f} ms")
    print(f" Min : {times[0]:.1f} ms")


def main() -> None:
    """Parse command-line arguments, export the model, optionally benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--saved_model", default="models/waste_classifier_v1")
    parser.add_argument("--data_dir", default="data/processed")
    parser.add_argument("--output", default="models/model.tflite")
    parser.add_argument("--benchmark", action="store_true")
    args = parser.parse_args()

    export(args.saved_model, args.data_dir, args.output)
    if args.benchmark:
        benchmark(args.output)


if __name__ == "__main__":
    main()