| """ | |
| scripts/export_tflite.py | |
| Quantize a SavedModel to an INT8 TFLite model for deployment. | |
| Usage: | |
| python scripts/export_tflite.py \ | |
| --saved_model models/waste_classifier_v1 \ | |
| --data_dir data/processed \ | |
| --output models/model.tflite | |
| """ | |
import argparse
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# MobileNetV2's input preprocessing; also applied to calibration images.
PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input
def representative_dataset(data_dir: str, n_samples: int = 200):
    """Yield calibration batches for INT8 post-training quantization.

    Args:
        data_dir: Dataset root; images are read from ``data_dir/train``.
        n_samples: Number of single-image batches to yield.

    Yields:
        One-element lists holding a float32 image batch of shape (1, 224, 224, 3).
    """
    datagen = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)
    flow = datagen.flow_from_directory(
        os.path.join(data_dir, "train"),
        target_size=(224, 224),
        batch_size=1,
        class_mode=None,
        shuffle=True,
        seed=42,
    )
    yielded = 0
    for batch in flow:
        # Keras directory iterators loop forever; stop explicitly.
        if yielded >= n_samples:
            break
        yielded += 1
        yield [batch.astype(np.float32)]
def export(saved_model_path: str, data_dir: str, output_path: str):
    """Convert a SavedModel to a fully-INT8 TFLite model and report sizes.

    Args:
        saved_model_path: Directory containing the TF SavedModel.
        data_dir: Dataset root; ``data_dir/train`` supplies calibration images.
        output_path: Destination path for the .tflite file.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Calibration samples drive the per-tensor quantization ranges.
    converter.representative_dataset = lambda: representative_dataset(data_dir)
    # Require INT8 kernels for every op so conversion fails loudly rather
    # than silently falling back to float.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as file:
        file.write(tflite_model)

    original_mb = _dir_size_mb(saved_model_path)
    tflite_mb = os.path.getsize(output_path) / 1024 / 1024
    print(f"Original SavedModel : {original_mb:.1f} MB")
    print(f"TFLite INT8 model : {tflite_mb:.1f} MB")
    # Guard: os.walk on an empty/missing dir yields nothing, making
    # original_mb == 0 and the reduction ratio a ZeroDivisionError.
    if original_mb > 0:
        print(f"Reduction : {(1 - tflite_mb / original_mb) * 100:.0f}%")
    print(f"Exported -> {output_path}")


def _dir_size_mb(path: str) -> float:
    """Return the total size of all files under *path*, in MiB."""
    total_bytes = sum(
        os.path.getsize(os.path.join(root, filename))
        for root, _, files in os.walk(path)
        for filename in files
    )
    return total_bytes / 1024 / 1024
def benchmark(model_path: str, n_runs: int = 50):
    """Run a quick CPU latency benchmark on a TFLite model.

    Args:
        model_path: Path to a .tflite file.
        n_runs: Number of timed inference calls (after one warm-up call).
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]

    # Build the dummy input from the model's own shape/dtype instead of
    # hard-coding (1, 224, 224, 3)/uint8 — this also works for float models.
    # Note: np.random.randint's upper bound is exclusive, so use max + 1 to
    # cover the full integer range (the original excluded 255).
    shape = tuple(input_details["shape"])
    dtype = np.dtype(input_details["dtype"])
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        dummy = np.random.randint(info.min, info.max + 1, shape, dtype=dtype)
    else:
        dummy = np.random.random_sample(shape).astype(dtype)

    # Warm-up: the first invoke pays one-time allocation/initialization costs.
    interpreter.set_tensor(input_details["index"], dummy)
    interpreter.invoke()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        interpreter.set_tensor(input_details["index"], dummy)
        interpreter.invoke()
        times.append((time.perf_counter() - start) * 1000)

    print(f"\nLatency over {n_runs} runs (CPU):")
    print(f" Median : {np.median(times):.1f} ms")
    print(f" p95 : {np.percentile(times, 95):.1f} ms")
    print(f" Min : {min(times):.1f} ms")
if __name__ == "__main__":
    # CLI entry point: export always runs; latency benchmark is opt-in.
    cli = argparse.ArgumentParser()
    cli.add_argument("--saved_model", default="models/waste_classifier_v1")
    cli.add_argument("--data_dir", default="data/processed")
    cli.add_argument("--output", default="models/model.tflite")
    cli.add_argument("--benchmark", action="store_true")
    opts = cli.parse_args()

    export(opts.saved_model, opts.data_dir, opts.output)
    if opts.benchmark:
        benchmark(opts.output)