# person-classification-tflite / inference_example.py
# Jahnavibh — Initial commit: Person Classification TensorFlow Lite Models (0eef5f7)
#!/usr/bin/env python3
"""
Person Classification TensorFlow Lite Inference Example
This script demonstrates how to use the person classification TFLite models
for binary classification (person vs. non-person) on input images.
Usage:
python inference_example.py --model flash --image test_image.jpg
python inference_example.py --model sram --image test_image.jpg
"""
import argparse
import sys
import os
import time
import numpy as np
import tensorflow as tf
from PIL import Image
from pathlib import Path
class PersonClassifier:
    """Binary person / non-person classification with a TensorFlow Lite model.

    Wraps a ``tf.lite.Interpreter`` and provides a simple
    image-path -> result-dict pipeline. On any failure the process exits
    with status 1 — this class is written for CLI/demo use, not library use.
    """

    def __init__(self, model_path):
        """
        Initialize the classifier with a TFLite model.

        Args:
            model_path (str): Path to the .tflite model file.
        """
        self.model_path = model_path
        self.interpreter = None
        self.input_details = None
        self.output_details = None
        self.input_shape = None
        self.load_model()

    def load_model(self):
        """Load the TFLite model, allocate tensors, and cache I/O metadata."""
        try:
            self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            # Image models report shape as [batch, height, width, channels].
            self.input_shape = self.input_details[0]['shape']
            print(f"βœ… Model loaded successfully: {self.model_path}")
            print(f"πŸ“Š Input shape: {self.input_shape}")
            print(f"πŸ“Š Input dtype: {self.input_details[0]['dtype']}")
            print(f"πŸ“Š Output shape: {self.output_details[0]['shape']}")
            print(f"πŸ“Š Output dtype: {self.output_details[0]['dtype']}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            sys.exit(1)

    def preprocess_image(self, image_path):
        """
        Load, resize, and type-convert an image for model inference.

        Args:
            image_path (str): Path to input image.

        Returns:
            np.ndarray: Batched array of shape [1, H, W, 3] whose dtype
            matches the model's declared input dtype.
        """
        try:
            image = Image.open(image_path).convert('RGB')
            print(f"πŸ“Έ Original image size: {image.size}")
            # Model input shape is [batch, height, width, channels],
            # while PIL's resize takes (width, height).
            target_height = self.input_shape[1]
            target_width = self.input_shape[2]
            target_size = (target_width, target_height)
            image = image.resize(target_size, Image.Resampling.BILINEAR)
            print(f"πŸ”„ Resized to: {target_size} (WΓ—H)")
            image_array = np.array(image, dtype=np.uint8)
            # BUGFIX: the original always fed uint8 data, which makes
            # set_tensor() fail for models that declare a non-uint8 input.
            # Cast to the dtype the model actually expects; the uint8 path
            # is byte-for-byte identical to the old behavior.
            expected_dtype = np.dtype(self.input_details[0]['dtype'])
            if expected_dtype != np.uint8:
                if np.issubdtype(expected_dtype, np.floating):
                    # NOTE(review): assumes float models expect inputs in
                    # [0, 1] — confirm against the model's training pipeline.
                    image_array = image_array.astype(expected_dtype) / 255.0
                else:
                    image_array = image_array.astype(expected_dtype)
            # Add batch dimension -> [1, height, width, channels].
            image_batch = np.expand_dims(image_array, axis=0)
            print(f"βœ… Preprocessed shape: {image_batch.shape}")
            print(f"πŸ“Š Value range: [{image_batch.min()}, {image_batch.max()}]")
            return image_batch
        except Exception as e:
            print(f"❌ Error preprocessing image: {e}")
            sys.exit(1)

    def predict(self, image_data):
        """
        Run inference on a preprocessed image batch.

        Args:
            image_data (np.ndarray): Output of :meth:`preprocess_image`.

        Returns:
            tuple: (probability, prediction_label, confidence) where
            probability is P(person), label is "Person"/"Non-person", and
            confidence is the probability assigned to the winning label.
        """
        try:
            self.interpreter.set_tensor(self.input_details[0]['index'], image_data)
            start_time = time.time()
            self.interpreter.invoke()
            inference_time = time.time() - start_time
            output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
            # Quantized outputs carry a non-zero scale; float models report
            # quantization params of (0.0, 0).
            scale, zero_point = self.output_details[0]['quantization']
            if scale != 0:  # Quantized output
                dequantized_output = scale * (output_data.astype(np.float32) - zero_point)
                # Quantized head emits a logit; squash it with a sigmoid.
                probability = 1 / (1 + np.exp(-dequantized_output[0][0]))
                print(f"πŸ”’ Quantized output dequantized: {dequantized_output[0][0]:.4f}")
            else:  # Float output
                # NOTE(review): float output is treated as an already-sigmoided
                # probability — confirm this holds for float exports.
                probability = float(output_data[0][0])
            prediction_label = "Person" if probability > 0.5 else "Non-person"
            # Confidence is the probability of whichever class was chosen.
            confidence = probability if probability > 0.5 else (1 - probability)
            print(f"⏱️ Inference time: {inference_time*1000:.2f}ms")
            return probability, prediction_label, confidence
        except Exception as e:
            print(f"❌ Error during inference: {e}")
            sys.exit(1)

    def classify_image(self, image_path):
        """
        Complete pipeline: preprocess an image file and classify it.

        Args:
            image_path (str): Path to input image.

        Returns:
            dict: keys 'image_path', 'prediction', 'probability',
            'confidence', 'model_used'.
        """
        print(f"\nπŸ” Classifying image: {image_path}")
        print("=" * 50)
        image_data = self.preprocess_image(image_path)
        probability, prediction_label, confidence = self.predict(image_data)
        return {
            'image_path': image_path,
            'prediction': prediction_label,
            'probability': probability,
            'confidence': confidence,
            'model_used': self.model_path,
        }
def print_results(results):
    """Pretty-print a classification result dict to stdout."""
    separator = "=" * 50
    model_name = Path(results['model_used']).name
    lines = [
        "\nπŸ“‹ CLASSIFICATION RESULTS",
        separator,
        f"πŸ–ΌοΈ Image: {results['image_path']}",
        f"🎯 Prediction: {results['prediction']}",
        f"πŸ“Š Probability: {results['probability']:.4f}",
        f"βœ… Confidence: {results['confidence']:.1%}",
        f"πŸ€– Model: {model_name}",
        separator,
    ]
    for line in lines:
        print(line)
def main():
    """CLI entry point: parse arguments, resolve the model path, classify one image."""
    parser = argparse.ArgumentParser(
        description="Person Classification using TensorFlow Lite models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python inference_example.py --model flash --image test_image.jpg
python inference_example.py --model sram --image test_image.jpg
python inference_example.py --model person_classification_flash(448x640).tflite --image test_image.jpg
""",
    )
    parser.add_argument(
        '--model',
        required=True,
        help='Model to use: "flash", "sram", or path to .tflite file',
    )
    parser.add_argument(
        '--image',
        required=True,
        help='Path to input image file',
    )
    args = parser.parse_args()

    # Map the convenience aliases to bundled model files; anything else is
    # treated as an explicit path.
    aliases = {
        'flash': 'person_classification_flash(448x640).tflite',
        'sram': 'person_classification_sram(256x448).tflite',
    }
    model_path = aliases.get(args.model.lower(), args.model)

    # Bail out early if either input file is missing.
    if not os.path.exists(model_path):
        print(f"❌ Model file not found: {model_path}")
        sys.exit(1)
    if not os.path.exists(args.image):
        print(f"❌ Image file not found: {args.image}")
        sys.exit(1)

    print("πŸš€ Person Classification TensorFlow Lite Demo")
    print("=" * 50)
    classifier = PersonClassifier(model_path)
    print_results(classifier.classify_image(args.image))
def demo_both_models():
    """Run both bundled models (if present) against a synthetic test image."""
    print("πŸš€ Person Classification Demo - Both Models")
    print("=" * 50)

    candidates = [
        ('Flash Model (VGA)', 'person_classification_flash(448x640).tflite'),
        ('SRAM Model (WQVGA)', 'person_classification_sram(256x448).tflite'),
    ]

    # Generate a placeholder image (solid-color rectangle) when no test
    # image is available on disk.
    test_image_path = 'test_person.jpg'
    if not os.path.exists(test_image_path):
        print(f"ℹ️ Creating test image: {test_image_path}")
        Image.new('RGB', (640, 480), color='lightblue').save(test_image_path)

    for label, path in candidates:
        if not os.path.exists(path):
            print(f"⚠️ {label} not found: {path}")
            continue
        print(f"\nπŸ” Testing {label}")
        print("-" * 30)
        print_results(PersonClassifier(path).classify_image(test_image_path))
if __name__ == '__main__':
    # No CLI arguments -> run the self-contained demo over both models;
    # otherwise defer to the argparse-driven entry point.
    (demo_both_models if len(sys.argv) == 1 else main)()