File size: 4,811 Bytes
efb1801 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
#!/usr/bin/env python3
"""
Real-time strawberry detection/classification using TFLite model.
Supports both binary classification (good/bad) and YOLOv8 detection.
"""
import argparse
import cv2
import numpy as np
import tensorflow as tf
from pathlib import Path
import sys
def load_tflite_model(model_path):
    """Build a TFLite interpreter for *model_path* and allocate its tensors.

    Args:
        model_path: filesystem path to a .tflite model file.

    Returns:
        A ready-to-invoke ``tf.lite.Interpreter``.

    Raises:
        FileNotFoundError: if *model_path* does not exist on disk.
    """
    model_file = Path(model_path)
    if not model_file.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # Tensors must be allocated before set_tensor/invoke can be used.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter
def get_model_details(interpreter):
    """Return the interpreter's ``(input_details, output_details)`` pair."""
    return interpreter.get_input_details(), interpreter.get_output_details()
def preprocess_image(image, input_shape):
    """Resize, normalize, and batch an image for TFLite inference.

    Args:
        image: frame as produced by OpenCV, i.e. an (H, W, C) uint8 array.
            NOTE(review): no BGR->RGB conversion is performed here — confirm
            the model was trained on BGR-ordered input.
        input_shape: the model's input tensor shape, assumed NHWC so that
            height and width sit at indices 1 and 2.

    Returns:
        float32 array of shape (1, height, width, C) with values in [0, 1].
    """
    # The original conditional had two identical branches
    # (`x if len(...) == 4 else x`); NHWC is assumed either way.
    height, width = input_shape[1:3]
    img = cv2.resize(image, (width, height))  # cv2.resize takes (W, H)
    img = img / 255.0  # scale uint8 pixels to [0, 1]
    return np.expand_dims(img, axis=0).astype(np.float32)
def run_inference(interpreter, input_details, output_details, preprocessed_img):
    """Push one preprocessed batch through the interpreter.

    Args:
        interpreter: an allocated TFLite interpreter.
        input_details: list of input tensor descriptors; index 0 is used.
        output_details: list of output tensor descriptors; index 0 is used.
        preprocessed_img: batched input array matching the model's input shape.

    Returns:
        The raw contents of the first output tensor after invocation.
    """
    in_index = input_details[0]['index']
    out_index = output_details[0]['index']
    interpreter.set_tensor(in_index, preprocessed_img)
    interpreter.invoke()
    return interpreter.get_tensor(out_index)
def main():
    """Parse CLI arguments and run the webcam inference loop until 'q' quits.

    Exits with status 1 if the model cannot be loaded or the camera cannot
    be opened. Press 's' during the loop to save the current frame to disk.
    """
    parser = argparse.ArgumentParser(description='Real-time strawberry detection/classification')
    parser.add_argument('--model', type=str, default='strawberry_model.tflite',
                        help='Path to TFLite model (default: strawberry_model.tflite)')
    parser.add_argument('--camera', type=int, default=0,
                        help='Camera index (default: 0)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Confidence threshold for binary classification (default: 0.5)')
    # NOTE(review): --input-size is parsed but never used — the input size is
    # taken from the model's own input shape below. Kept for CLI compatibility.
    parser.add_argument('--input-size', type=int, default=224,
                        help='Input image size (width=height) for model (default: 224)')
    parser.add_argument('--mode', choices=['classification', 'detection'], default='classification',
                        help='Inference mode: classification (good/bad) or detection (YOLO)')
    parser.add_argument('--verbose', action='store_true',
                        help='Print detailed inference information')
    args = parser.parse_args()

    # Load model; abort early with a clear message on any failure.
    try:
        interpreter = load_tflite_model(args.model)
        input_details, output_details = get_model_details(interpreter)
        input_shape = input_details[0]['shape']
        if args.verbose:
            print(f"Model loaded: {args.model}")
            print(f"Input shape: {input_shape}")
            print(f"Output details: {output_details[0]}")
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    # Open camera.
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print(f"Cannot open camera index {args.camera}")
        sys.exit(1)

    print(f"Starting real-time inference (mode: {args.mode})")
    print("Press 'q' to quit, 's' to save current frame")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to capture frame")
                break

            # Preprocess and run inference on this frame.
            preprocessed = preprocess_image(frame, input_shape)
            predictions = run_inference(interpreter, input_details, output_details, preprocessed)

            # Interpret predictions according to the selected mode.
            if args.mode == 'classification':
                # Binary classification: single probability in predictions[0][0].
                confidence = predictions[0][0]
                is_good = confidence > args.threshold
                label = 'Good' if is_good else 'Bad'
                display_text = f'{label}: {confidence:.2f}'
                color = (0, 255, 0) if is_good else (0, 0, 255)  # BGR: green / red
            else:
                # Detection mode (YOLO) - placeholder for future implementation
                display_text = 'Detection mode not yet implemented'
                color = (255, 255, 0)

            # Overlay result and show the frame.
            cv2.putText(frame, display_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            cv2.imshow('Strawberry Detection', frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('s'):
                filename = f'capture_{cv2.getTickCount()}.jpg'
                cv2.imwrite(filename, frame)
                # BUG FIX: the original printed the literal text "(unknown)"
                # instead of the actual saved filename.
                print(f"Frame saved as {filename}")
    finally:
        # Always release the camera and GUI windows, even on an unexpected error.
        cap.release()
        cv2.destroyAllWindows()

    print("Real-time detection stopped.")
# Script entry point (the trailing extraction residue on the original final
# line was removed — it was not valid Python).
if __name__ == '__main__':
    main()