File size: 4,811 Bytes
efb1801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""
Real-time strawberry classification from a camera feed using a TFLite model.
Binary classification (good/bad) is implemented; the detection (YOLOv8) mode
is currently a placeholder that only overlays a "not yet implemented" notice.
"""

import argparse
import cv2
import numpy as np
import tensorflow as tf
from pathlib import Path
import sys

def load_tflite_model(model_path):
    """Load a TFLite model and allocate its tensors.

    Args:
        model_path: Location of the ``.tflite`` file, given as a ``str`` or
            a ``pathlib.Path``.

    Returns:
        A ``tf.lite.Interpreter`` on which ``allocate_tensors()`` has
        already been called.

    Raises:
        FileNotFoundError: If ``model_path`` is missing or is not a regular
            file (for example, a directory).
    """
    model_file = Path(model_path)
    # is_file() instead of exists(): a directory passes exists() and would
    # only fail later, inside the interpreter, with a less helpful error.
    if not model_file.is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # Pass a plain string so callers may supply Path objects as well.
    interpreter = tf.lite.Interpreter(model_path=str(model_file))
    interpreter.allocate_tensors()
    return interpreter

def get_model_details(interpreter):
    """Return the interpreter's (input details, output details) as a pair."""
    return interpreter.get_input_details(), interpreter.get_output_details()

def preprocess_image(image, input_shape):
    """Resize and scale an image into a batched float32 model input.

    Args:
        image: Image array accepted by ``cv2.resize``.
        input_shape: Input shape of the model. Entries 1 and 2 are read as
            height and width (assumes a batch-first layout such as
            [N, H, W, C] -- TODO confirm against the model in use).

    Returns:
        A ``float32`` array holding the resized image divided by 255, with
        one extra leading axis of length 1.

    Raises:
        ValueError: If ``input_shape`` has fewer than three entries, so no
            height/width pair can be read from it.
    """
    if len(input_shape) < 3:
        raise ValueError(
            f"input_shape must have at least 3 dimensions, got {len(input_shape)}"
        )

    # The original branched on len(input_shape) == 4 but both branches read
    # the same slice; the slice is taken once here. Values are cast to plain
    # ints so cv2.resize receives a tuple of Python integers even when the
    # shape is a NumPy integer array.
    height, width = (int(dim) for dim in input_shape[1:3])

    # NOTE(review): no channel reordering is done here; frames read through
    # OpenCV are BGR, so confirm the model expects that order.
    resized = cv2.resize(image, (width, height))
    scaled = resized / 255.0  # Normalize to [0,1]
    return np.expand_dims(scaled, axis=0).astype(np.float32)

def run_inference(interpreter, input_details, output_details, preprocessed_img):
    """Feed one preprocessed image to the interpreter and return its first output.

    The image is written to the tensor named by the first entry of
    ``input_details``, the interpreter is invoked, and the tensor named by
    the first entry of ``output_details`` is read back and returned.
    """
    input_index = input_details[0]['index']
    interpreter.set_tensor(input_index, preprocessed_img)
    interpreter.invoke()
    output_index = output_details[0]['index']
    return interpreter.get_tensor(output_index)

def main():
    """Run the real-time camera inference loop from command-line arguments.

    Parses the CLI options, loads the TFLite model, opens the requested
    camera, and then repeatedly grabs a frame, runs inference on it, and
    shows the frame with a text overlay. Pressing 'q' ends the loop and 's'
    writes the currently displayed frame to a JPEG file.

    Exits the process with status 1 if the model cannot be loaded or the
    camera cannot be opened. The camera and the display windows are released
    even when the loop is interrupted by an exception.
    """
    parser = argparse.ArgumentParser(description='Real-time strawberry detection/classification')
    parser.add_argument('--model', type=str, default='strawberry_model.tflite',
                        help='Path to TFLite model (default: strawberry_model.tflite)')
    parser.add_argument('--camera', type=int, default=0,
                        help='Camera index (default: 0)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Confidence threshold for binary classification (default: 0.5)')
    # NOTE: --input-size is accepted for CLI compatibility but is not read
    # below; the resize target comes from the model's own input shape.
    parser.add_argument('--input-size', type=int, default=224,
                        help='Input image size (width=height) for model (default: 224)')
    parser.add_argument('--mode', choices=['classification', 'detection'], default='classification',
                        help='Inference mode: classification (good/bad) or detection (YOLO)')
    parser.add_argument('--verbose', action='store_true',
                        help='Print detailed inference information')

    args = parser.parse_args()

    # Load model
    try:
        interpreter = load_tflite_model(args.model)
        input_details, output_details = get_model_details(interpreter)
        input_shape = input_details[0]['shape']
        if args.verbose:
            print(f"Model loaded: {args.model}")
            print(f"Input shape: {input_shape}")
            print(f"Output details: {output_details[0]}")
    except Exception as e:
        # Top-level boundary: report any load failure and stop.
        print(f"Error loading model: {e}")
        sys.exit(1)

    # Open camera
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print(f"Cannot open camera index {args.camera}")
        cap.release()
        sys.exit(1)

    print(f"Starting real-time inference (mode: {args.mode})")
    print("Press 'q' to quit, 's' to save current frame")

    # try/finally so the capture device and windows are released even if
    # inference raises or the user interrupts with Ctrl-C.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to capture frame")
                break

            # Preprocess
            preprocessed = preprocess_image(frame, input_shape)

            # Inference
            predictions = run_inference(interpreter, input_details, output_details, preprocessed)

            # Process predictions based on mode
            if args.mode == 'classification':
                # Binary classification: the first value of the first output
                # row is treated as the probability of "Good".
                confidence = predictions[0][0]
                is_good = confidence > args.threshold
                label = 'Good' if is_good else 'Bad'
                display_text = f'{label}: {confidence:.2f}'
                color = (0, 255, 0) if is_good else (0, 0, 255)
            else:
                # Detection mode (YOLO) - placeholder for future implementation
                display_text = 'Detection mode not yet implemented'
                color = (255, 255, 0)

            # Display
            cv2.putText(frame, display_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            cv2.imshow('Strawberry Detection', frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('s'):
                # The saved image is the annotated frame (overlay included).
                filename = f'capture_{cv2.getTickCount()}.jpg'
                # imwrite signals failure through its return value, so only
                # claim success when it actually wrote the file.
                if cv2.imwrite(filename, frame):
                    print(f"Frame saved as {filename}")
                else:
                    print(f"Failed to save frame as {filename}")
    finally:
        cap.release()
        cv2.destroyAllWindows()

    print("Real-time detection stopped.")

# Start the camera loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()