| | |
| | """ |
| | Person Classification TensorFlow Lite Inference Example |
| | |
| | This script demonstrates how to use the person classification TFLite models |
| | for binary classification (person vs. non-person) on input images. |
| | |
| | Usage: |
| | python inference_example.py --model flash --image test_image.jpg |
| | python inference_example.py --model sram --image test_image.jpg |
| | """ |
| |
|
| | import argparse |
| | import sys |
| | import os |
| | import time |
| | import numpy as np |
| | import tensorflow as tf |
| | from PIL import Image |
| | from pathlib import Path |
| |
|
class PersonClassifier:
    """Binary person / non-person classifier backed by a TensorFlow Lite model.

    The model is loaded eagerly in ``__init__``. Any failure while loading,
    preprocessing or running inference is reported on stdout and terminates
    the process with exit status 1 (the methods call ``sys.exit(1)``).
    """

    # TFLite's 8-bit quantization spec fixes the output of a LOGISTIC
    # (sigmoid) op to scale 1/256 with zero_point -128 (int8) or 0 (uint8).
    # An output tensor with exactly these parameters already holds a
    # probability, so it must not be passed through another sigmoid.
    _LOGISTIC_OUTPUT_SCALE = 1.0 / 256.0
    _LOGISTIC_OUTPUT_ZERO_POINTS = {np.dtype(np.int8): -128, np.dtype(np.uint8): 0}

    def __init__(self, model_path):
        """
        Initialize the classifier and load the TFLite model immediately.

        Args:
            model_path (str): Path to the .tflite model file
        """
        self.model_path = model_path
        self.interpreter = None
        self.input_details = None
        self.output_details = None
        self.input_shape = None
        self.load_model()

    def load_model(self):
        """Load the TFLite model, allocate its tensors and cache tensor metadata.

        Exits the process with status 1 if the model cannot be loaded.
        """
        try:
            self.interpreter = tf.lite.Interpreter(model_path=self.model_path)
            self.interpreter.allocate_tensors()

            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            self.input_shape = self.input_details[0]['shape']

            print(f"✅ Model loaded successfully: {self.model_path}")
            print(f"📊 Input shape: {self.input_shape}")
            print(f"📊 Input dtype: {self.input_details[0]['dtype']}")
            print(f"📊 Output shape: {self.output_details[0]['shape']}")
            print(f"📊 Output dtype: {self.output_details[0]['dtype']}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            sys.exit(1)

    @staticmethod
    def _to_input_dtype(image_array, input_dtype):
        """Map a uint8 pixel array onto the model's input tensor dtype.

        int8 models get the standard uint8 -> int8 re-centering
        (pixel - 128, as in the TensorFlow Lite Micro person-detection
        example). Every other dtype is returned unchanged, so ``set_tensor``
        still reports a mismatch (e.g. for a float model) instead of silently
        receiving unscaled data.
        NOTE(review): assumes the int8 input uses the conventional
        zero_point shift — confirm against input_details['quantization'].
        """
        if np.dtype(input_dtype) == np.int8:
            # Widen first so the subtraction cannot wrap around in uint8.
            return (image_array.astype(np.int16) - 128).astype(np.int8)
        return image_array

    def preprocess_image(self, image_path):
        """
        Load an image and convert it into a batch the model can consume.

        The image is converted to RGB, resized (bilinear) to the model's
        spatial input size and given a leading batch dimension. Pixels stay
        uint8 in [0, 255], except for int8-input models, where they are
        shifted to [-128, 127].

        Args:
            image_path (str): Path to input image

        Returns:
            np.ndarray: Array of shape (1, H, W, 3) ready for ``predict``

        Exits the process with status 1 on any error.
        """
        try:
            # The context manager closes the file handle; convert() returns a
            # new, fully loaded image that remains valid after the block.
            with Image.open(image_path) as source:
                image = source.convert('RGB')
            print(f"📸 Original image size: {image.size}")

            # Assumes an NHWC input tensor (batch, height, width, channels)
            # — TODO confirm for any new model.
            target_height = int(self.input_shape[1])
            target_width = int(self.input_shape[2])
            # PIL sizes are (width, height).
            target_size = (target_width, target_height)

            image = image.resize(target_size, Image.Resampling.BILINEAR)
            print(f"📐 Resized to: {target_size} (W×H)")

            image_array = np.array(image, dtype=np.uint8)
            image_array = self._to_input_dtype(image_array, self.input_details[0]['dtype'])

            image_batch = np.expand_dims(image_array, axis=0)

            print(f"✅ Preprocessed shape: {image_batch.shape}")
            print(f"📊 Value range: [{image_batch.min()}, {image_batch.max()}]")

            return image_batch

        except Exception as e:
            print(f"❌ Error preprocessing image: {e}")
            sys.exit(1)

    @classmethod
    def _to_probability(cls, output_data, quantization):
        """Convert the model's first output value into a "Person" probability.

        Args:
            output_data (np.ndarray): Raw output tensor; only its first
                element is used.
            quantization (tuple): (scale, zero_point) of the output tensor;
                scale == 0 means the output is not quantized.

        Returns:
            float: The probability that the image contains a person.
        """
        output_data = np.asarray(output_data)
        scale, zero_point = quantization
        # reshape(-1) accepts (1, 1)- as well as (1,)-shaped outputs.
        raw_value = output_data.reshape(-1)[0]

        if scale == 0:
            # Unquantized output is used as a probability as-is.
            # NOTE(review): a float model that emits logits would need a sigmoid here.
            return float(raw_value)

        dequantized = scale * (float(raw_value) - zero_point)
        print(f"🔢 Quantized output dequantized: {dequantized:.4f}")

        expected_zero_point = cls._LOGISTIC_OUTPUT_ZERO_POINTS.get(output_data.dtype)
        if (expected_zero_point is not None
                and zero_point == expected_zero_point
                and np.isclose(scale, cls._LOGISTIC_OUTPUT_SCALE)):
            # Already a sigmoid output. A second sigmoid would squash every
            # value into [0.5, 0.73] and label almost everything "Person".
            return dequantized

        # Any other quantized output is treated as a logit.
        return float(1.0 / (1.0 + np.exp(-dequantized)))

    def predict(self, image_data):
        """
        Run inference on preprocessed image data.

        Args:
            image_data (np.ndarray): Output of ``preprocess_image``

        Returns:
            tuple: (probability, prediction_label, confidence). probability is
            the "Person" probability as a float; prediction_label is "Person"
            when probability > 0.5, else "Non-person"; confidence is the
            probability of the returned label.

        Exits the process with status 1 on any error.
        """
        try:
            self.interpreter.set_tensor(self.input_details[0]['index'], image_data)

            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            self.interpreter.invoke()
            inference_time = time.perf_counter() - start_time

            output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
            probability = self._to_probability(output_data, self.output_details[0]['quantization'])

            prediction_label = "Person" if probability > 0.5 else "Non-person"
            confidence = probability if probability > 0.5 else (1 - probability)

            print(f"⏱️ Inference time: {inference_time*1000:.2f}ms")

            return probability, prediction_label, confidence

        except Exception as e:
            print(f"❌ Error during inference: {e}")
            sys.exit(1)

    def classify_image(self, image_path):
        """
        Run the full pipeline (preprocess + predict) on one image file.

        Args:
            image_path (str): Path to input image

        Returns:
            dict: Keys 'image_path', 'prediction', 'probability',
            'confidence' and 'model_used' (the model path).
        """
        print(f"\n🔍 Classifying image: {image_path}")
        print("=" * 50)

        image_data = self.preprocess_image(image_path)
        probability, prediction_label, confidence = self.predict(image_data)

        return {
            'image_path': image_path,
            'prediction': prediction_label,
            'probability': probability,
            'confidence': confidence,
            'model_used': self.model_path,
        }
| |
|
def print_results(results):
    """Print a classification result dict in a human-readable layout.

    Args:
        results (dict): Mapping with the keys 'image_path', 'prediction',
            'probability', 'confidence' and 'model_used'.
    """
    print("\n📊 CLASSIFICATION RESULTS")
    print("=" * 50)
    print(f"🖼️ Image: {results['image_path']}")
    print(f"🎯 Prediction: {results['prediction']}")
    print(f"📈 Probability: {results['probability']:.4f}")
    print(f"✅ Confidence: {results['confidence']:.1%}")
    # Only the model's file name is shown, not its full path.
    print(f"🤖 Model: {Path(results['model_used']).name}")
    print("=" * 50)
| |
|
def main():
    """Command-line entry point: classify one image with the chosen model.

    ``--model`` accepts the shortcuts "flash" / "sram" (case-insensitive),
    which map to the bundled model file names, or a path to any .tflite file.
    Exits with status 1 if the model or image file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Person Classification using TensorFlow Lite models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python inference_example.py --model flash --image test_image.jpg
python inference_example.py --model sram --image test_image.jpg
python inference_example.py --model person_classification_flash(448x640).tflite --image test_image.jpg
"""
    )

    parser.add_argument(
        '--model',
        required=True,
        help='Model to use: "flash", "sram", or path to .tflite file'
    )

    parser.add_argument(
        '--image',
        required=True,
        help='Path to input image file'
    )

    args = parser.parse_args()

    # Shortcut names for the two bundled models; anything else is a path.
    model_aliases = {
        'flash': 'person_classification_flash(448x640).tflite',
        'sram': 'person_classification_sram(256x448).tflite',
    }
    model_path = model_aliases.get(args.model.lower(), args.model)

    # isfile (not exists) so a directory gets the clear "not found" message
    # instead of a later, less obvious loader error.
    if not os.path.isfile(model_path):
        print(f"❌ Model file not found: {model_path}")
        sys.exit(1)

    if not os.path.isfile(args.image):
        print(f"❌ Image file not found: {args.image}")
        sys.exit(1)

    print("🚀 Person Classification TensorFlow Lite Demo")
    print("=" * 50)

    classifier = PersonClassifier(model_path)
    results = classifier.classify_image(args.image)
    print_results(results)
| |
|
def demo_both_models():
    """Run both bundled models on a sample image, skipping any missing model.

    If ``test_person.jpg`` does not exist in the working directory, a solid
    light-blue 640x480 placeholder image is written there first.
    """
    print("🚀 Person Classification Demo - Both Models")
    print("=" * 50)

    models = [
        ('Flash Model (VGA)', 'person_classification_flash(448x640).tflite'),
        ('SRAM Model (WQVGA)', 'person_classification_sram(256x448).tflite'),
    ]

    test_image_path = 'test_person.jpg'
    if not os.path.exists(test_image_path):
        print(f"ℹ️ Creating test image: {test_image_path}")
        # Placeholder only: a single flat colour, so it contains no person.
        test_img = Image.new('RGB', (640, 480), color='lightblue')
        test_img.save(test_image_path)

    for model_name, model_path in models:
        if not os.path.exists(model_path):
            print(f"⚠️ {model_name} not found: {model_path}")
            continue

        print(f"\n🧪 Testing {model_name}")
        print("-" * 30)

        classifier = PersonClassifier(model_path)
        results = classifier.classify_image(test_image_path)
        print_results(results)
| |
|
if __name__ == '__main__':
    # Any command-line arguments select the CLI; running the script bare
    # (argv holds only the program name) launches the two-model demo.
    if len(sys.argv) != 1:
        main()
    else:
        demo_both_models()