# CLARIS / inference_examples / tflite_inference_example.py
# (Hugging Face upload metadata: uploaded by RunningPie via huggingface_hub,
#  commit d354874, verified.)
import tensorflow as tf
import numpy as np
import pickle
class TFLiteSignLanguagePredictor:
    """TensorFlow Lite predictor for sign language recognition.

    Wraps a .tflite interpreter plus a pickled preprocessing object and
    exposes single-sample and batch prediction helpers.
    """

    def __init__(self, tflite_model_path, processor_path):
        """
        Initialize the TFLite predictor.

        Args:
            tflite_model_path: Path to the .tflite model file.
            processor_path: Path to the pickled processor (processor.pkl).
        """
        # Load the TFLite model and allocate tensors up front so predict()
        # only has to set inputs and invoke.
        self.interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        self.interpreter.allocate_tensors()
        # Cache input/output tensor metadata (index, shape, dtype).
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # SECURITY NOTE: pickle.load executes arbitrary code while loading.
        # Only unpickle processor files that come from a trusted source.
        with open(processor_path, 'rb') as f:
            self.processor = pickle.load(f)
        # Plain string — the original used an f-string with no placeholders.
        print("TFLite model loaded successfully!")
        print(f"Input shape: {self.input_details[0]['shape']}")
        print(f"Output shape: {self.output_details[0]['shape']}")

    def predict(self, landmark_data):
        """
        Predict sign from landmark data.

        Args:
            landmark_data: DataFrame with raw landmark data, or an already
                preprocessed array-like (2-D sample or 3-D batch of one).

        Returns:
            tuple: (predicted_class_or_sign_name, confidence, all_probabilities),
            or (None, 0.0, None) when preprocessing yields no usable sample.
        """
        # Preprocess when given raw landmark data (duck-typed DataFrame check).
        if hasattr(landmark_data, 'columns'):
            X, _ = self.processor.process_dataset(landmark_data)
            if len(X) == 0:  # preprocessing produced no usable samples
                return None, 0.0, None
            input_data = X[0:1].astype(np.float32)  # keep the batch dimension
        else:  # already preprocessed; accept any array-like, not only ndarray
            input_data = np.asarray(landmark_data).astype(np.float32)
            if input_data.ndim == 2:  # add batch dimension if needed
                input_data = np.expand_dims(input_data, 0)

        # Run inference.
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        output_data = self.interpreter.get_tensor(self.output_details[0]['index'])

        # NOTE(review): assumes the model emits logits; if the exported model
        # already ends in softmax this double-normalizes — confirm.
        probabilities = tf.nn.softmax(output_data[0]).numpy()
        # Cast numpy scalars to builtin int/float for safer downstream use
        # (JSON serialization, dict keys, etc.).
        predicted_class = int(np.argmax(probabilities))
        confidence = float(probabilities[predicted_class])

        # Map index -> sign name when the processor provides a mapping.
        if hasattr(self.processor, 'index_to_sign'):
            sign_name = self.processor.index_to_sign[predicted_class]
            return sign_name, confidence, probabilities
        return predicted_class, confidence, probabilities

    def predict_batch(self, batch_data):
        """Predict multiple samples by repeated single inference.

        Note: some TFLite models only support batch size 1, so each sample
        is sliced out and run through predict() individually.

        Returns:
            list of predict() result tuples, one per sample.
        """
        results = []
        for i in range(len(batch_data)):
            # Slice (not index) to preserve the leading batch dimension.
            results.append(self.predict(batch_data[i:i + 1]))
        return results
# Example usage
if __name__ == "__main__":
    # Initialize predictor — requires the exported model and processor files
    # to exist next to this script.
    predictor = TFLiteSignLanguagePredictor(
        tflite_model_path="model_optimized.tflite",  # or model.tflite, model_quantized.tflite
        processor_path="processor.pkl"
    )
    # Example prediction (replace with your actual data)
    # landmark_data = pd.read_csv("your_landmark_data.csv")
    # predicted_sign, confidence, probabilities = predictor.predict(landmark_data)
    # print(f"Predicted: {predicted_sign}, Confidence: {confidence:.3f}")
    print("TFLite predictor ready for inference!")
# Performance comparison function
def compare_model_performance(keras_model, tflite_paths, test_data):
    """Compare inference speed between a Keras model and TFLite variants.

    Args:
        keras_model: Object exposing .predict(test_data) (e.g. a Keras model).
        tflite_paths: Mapping of {variant_name: path_to_tflite_file};
            may be empty to benchmark only the Keras model.
        test_data: Indexable batch of samples convertible to float32.

    Returns:
        None. Results are printed to stdout.
    """
    import time

    # BUG FIX: the original printed a literal "\n" (double-escaped "\\n")
    # instead of a newline.
    print("\n🏃 Performance Comparison:")
    print("-" * 50)

    # Time the Keras model over the whole batch in one call.
    start_time = time.time()
    _ = keras_model.predict(test_data)
    keras_time = time.time() - start_time
    print(f"Keras Model: {keras_time:.4f}s for {len(test_data)} samples")

    # Time each TFLite variant sample-by-sample (TFLite commonly runs batch=1).
    for name, path in tflite_paths.items():
        interpreter = tf.lite.Interpreter(model_path=path)
        interpreter.allocate_tensors()
        # Hoist loop-invariant tensor indices out of the timing loop.
        input_index = interpreter.get_input_details()[0]['index']
        output_index = interpreter.get_output_details()[0]['index']

        start_time = time.time()
        for i in range(len(test_data)):
            test_input = test_data[i:i + 1].astype(np.float32)
            interpreter.set_tensor(input_index, test_input)
            interpreter.invoke()
            _ = interpreter.get_tensor(output_index)
        tflite_time = time.time() - start_time

        # Guard against a zero elapsed time on very fast runs.
        speedup = keras_time / tflite_time if tflite_time > 0 else 0
        print(f"TFLite {name}: {tflite_time:.4f}s ({speedup:.1f}x {'faster' if speedup > 1 else 'slower'})")