# strawberryPicker/scripts/convert_tflite.py
# Author: Gareth
# Initial clean commit for Hugging Face (commit efb1801)
#!/usr/bin/env python3
"""
Convert trained model to TensorFlow Lite format with optional quantization.
Supports conversion from Keras (.h5) and PyTorch (.pt) models.
"""
import argparse
import tensorflow as tf
import os
from pathlib import Path
import sys
def convert_keras_to_tflite(input_path, output_path, quantization=None,
                            optimize_for_size=False, representative_dataset=None):
    """Convert a Keras model to TensorFlow Lite format.

    Args:
        input_path: Path to the saved Keras model (.h5 file).
        output_path: Destination path for the generated .tflite file.
        quantization: One of 'int8', 'float16', 'dynamic_range', or None
            (no quantization).
        optimize_for_size: If True, apply the default TFLite optimizations
            to shrink the model.
        representative_dataset: Callable/generator yielding calibration
            samples. Required for full INT8 quantization — without it the
            converter cannot compute activation ranges.

    Returns:
        output_path on success.

    Raises:
        FileNotFoundError: If input_path does not exist.
        ValueError: If 'int8' quantization is requested without a
            representative dataset.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input model not found: {input_path}")
    print(f"Loading Keras model from {input_path}...")
    model = tf.keras.models.load_model(input_path)
    # Configure converter
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Optimization options
    if optimize_for_size:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Quantization options
    if quantization == 'int8':
        # Full-integer quantization needs calibration data; fail early with
        # a clear message rather than letting converter.convert() error out
        # (or silently fall back to float ops).
        if representative_dataset is None:
            raise ValueError(
                "INT8 quantization requires a representative_dataset "
                "generator for calibration."
            )
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        # Restrict to the int8 builtin op set so conversion errors out if an
        # op cannot be quantized, instead of emitting mixed float/int ops.
        # (supported_types is the float16 mechanism, not the int8 one.)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        print("Using INT8 quantization")
    elif quantization == 'float16':
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
        print("Using Float16 quantization")
    elif quantization == 'dynamic_range':
        # Weights-only quantization; no calibration data needed.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        print("Using dynamic range quantization")
    # Convert model
    print("Converting model to TFLite...")
    tflite_model = converter.convert()
    # Save model
    with open(output_path, 'wb') as f:
        f.write(tflite_model)
    print(f"TFLite model saved to {output_path}")
    # Print model size
    size_kb = os.path.getsize(output_path) / 1024
    print(f"Model size: {size_kb:.2f} KB")
    return output_path
def convert_pytorch_to_tflite(input_path, output_path, quantization=None):
    """Placeholder: PyTorch -> TFLite conversion is not implemented yet.

    Prints guidance (the recommended route is PyTorch -> ONNX ->
    TensorFlow -> TFLite) and always returns None. The parameters mirror
    convert_keras_to_tflite for a consistent dispatch interface.
    """
    guidance = (
        "PyTorch to TFLite conversion not yet implemented.",
        "Please convert PyTorch model to ONNX first, then use TensorFlow's converter.",
    )
    for line in guidance:
        print(line)
    return None
def main():
    """CLI entry point: parse arguments, validate paths, and dispatch to the
    appropriate converter. Exits with status 1 on any failure."""
    parser = argparse.ArgumentParser(description='Convert model to TensorFlow Lite format')
    parser.add_argument('--input', type=str, default='strawberry_model.h5',
                        help='Input model path (Keras .h5 or PyTorch .pt)')
    parser.add_argument('--output', type=str, default='strawberry_model.tflite',
                        help='Output TFLite model path')
    parser.add_argument('--quantization', type=str, choices=['int8', 'float16', 'dynamic_range', 'none'],
                        default='none', help='Quantization method (default: none)')
    parser.add_argument('--optimize-for-size', action='store_true',
                        help='Apply size optimization (reduces model size)')
    parser.add_argument('--model-type', type=str, choices=['keras', 'pytorch'], default='keras',
                        help='Type of input model (default: keras)')
    args = parser.parse_args()
    # Fail early with a helpful listing of candidate model files.
    if not Path(args.input).exists():
        print(f"Error: Input model '{args.input}' not found.")
        print("Available model files in current directory:")
        for f in os.listdir('.'):
            if f.endswith(('.h5', '.pt', '.pth', '.onnx')):
                print(f" - {f}")
        sys.exit(1)
    # Create output directory if needed. makedirs(exist_ok=True) already
    # tolerates an existing directory, so no separate existence check.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # The converters expect None (not the string 'none') for "no quantization".
    quantization = args.quantization if args.quantization != 'none' else None
    try:
        # argparse 'choices' guarantees model_type is 'keras' or 'pytorch',
        # so no further validation branch is needed.
        if args.model_type == 'keras':
            convert_keras_to_tflite(
                input_path=args.input,
                output_path=args.output,
                quantization=quantization,
                optimize_for_size=args.optimize_for_size
            )
        else:
            convert_pytorch_to_tflite(
                input_path=args.input,
                output_path=args.output,
                quantization=quantization
            )
        print("Conversion completed successfully!")
    except Exception as e:
        # Top-level CLI boundary: report and exit nonzero rather than traceback.
        print(f"Error during conversion: {e}")
        sys.exit(1)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()