File size: 3,234 Bytes
b11642e
 
 
 
 
 
 
 
81a35ad
b11642e
 
 
cfe01e5
 
 
 
 
46ed9eb
b11642e
 
46ed9eb
cfe01e5
 
 
 
 
46ed9eb
b11642e
 
46ed9eb
 
6468f09
81a35ad
46ed9eb
b11642e
 
eedb862
46ed9eb
eedb862
b11642e
 
6468f09
b11642e
f98ab8d
b11642e
 
 
734fa17
b11642e
eedb862
74ac36a
eedb862
81a35ad
 
 
 
 
 
f98ab8d
 
81a35ad
 
74ac36a
 
 
 
f98ab8d
74ac36a
86b9261
74ac36a
86b9261
74ac36a
 
 
6468f09
74ac36a
86b9261
74ac36a
86b9261
74ac36a
 
 
6468f09
eedb862
86b9261
81a35ad
86b9261
74ac36a
 
 
f98ab8d
74ac36a
f98ab8d
 
eedb862
b11642e
 
eedb862
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import logging

import cv2
import gradio as gr
import numpy as np

from model import TrafficSignDetector

# Configuration file passed to TrafficSignDetector when the app starts.
DETECTOR_CONFIG_PATH = 'config.yaml'

# One shared detector instance, built once at import time and used by
# detect_traffic_signs for every request.
detector = TrafficSignDetector(DETECTOR_CONFIG_PATH)

logger = logging.getLogger(__name__)


def _to_rgb_array(image):
    """Normalize an uploaded image to an ``(H, W, 3)`` RGB numpy array.

    Accepts a PIL image (anything exposing ``convert``) or an array-like.
    Grayscale input (2-D or single-channel) is replicated into three channels
    and a fourth (alpha) channel is dropped, so the RGB->BGR conversion that
    follows always receives exactly three channels.

    Returns:
        A C-contiguous ``(H, W, 3)`` array, or ``None`` when the input is
        empty or has an unsupported shape.
    """
    if hasattr(image, "convert"):
        # PIL image: let Pillow map L / LA / P / RGBA / I;16 modes to RGB
        # instead of producing a 2-D or 4-channel array we would reject.
        image = image.convert("RGB")
    image = np.array(image)

    if image.size == 0:
        return None
    if image.ndim == 2:
        # Plain grayscale array -> add a channel axis, expanded below.
        image = image[:, :, np.newaxis]
    if image.ndim != 3:
        return None

    channels = image.shape[2]
    if channels == 1:
        return np.repeat(image, 3, axis=2)
    if channels == 3:
        return image
    if channels == 4:
        # Drop alpha; copy so OpenCV gets a contiguous buffer.
        return np.ascontiguousarray(image[:, :, :3])
    return None


def detect_traffic_signs(image, confidence_threshold):
    """
    Process the uploaded image and return the image with detected signs.

    Args:
        image: PIL image (as delivered by ``gr.Image(type="pil")``) or an RGB
            numpy array. Grayscale and RGBA inputs are also accepted.
        confidence_threshold: Minimum confidence forwarded unchanged to
            ``detector.detect``.

    Returns:
        ``(result_image, preprocessed_image)`` as RGB numpy arrays, or
        ``(None, None)`` when no usable image is provided (e.g. after the
        input has been cleared).
    """
    if image is None:
        # Normal when the input is cleared or the slider moves with no image.
        logger.debug("No image provided")
        return None, None

    logger.debug("Received image type: %s", type(image))
    rgb = _to_rgb_array(image)
    if rgb is None:
        logger.warning(
            "Invalid image: type=%s shape=%s",
            type(image).__name__,
            getattr(image, "shape", None),
        )
        return None, None

    # detector.detect is fed BGR, following the OpenCV convention.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    result_bgr, preprocessed_bgr = detector.detect(
        bgr, confidence_threshold=confidence_threshold
    )

    # Gradio displays RGB, so convert both outputs back.
    return (
        cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB),
        cv2.cvtColor(preprocessed_bgr, cv2.COLOR_BGR2RGB),
    )

# Initial slider value, also restored by the "Clear" button.
DEFAULT_CONFIDENCE = 0.30

# Create Gradio interface
with gr.Blocks(title="Traffic Sign Detector") as demo:
    gr.Markdown("# Traffic Sign Detector")
    gr.Markdown("Upload an image to detect traffic signs using YOLOv8. Detection runs automatically when you upload or adjust the threshold.")

    # Row 1: upload on the left, annotated result on the right (read-only).
    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Detected Signs", interactive=False)

    # Row 2: the preprocessed image returned by the detector (read-only).
    with gr.Row():
        preprocessed_image = gr.Image(label="Preprocessed Image (640x640, Letterboxed)", interactive=False)

    # Row 3: detection confidence control.
    with gr.Row():
        confidence_threshold = gr.Slider(
            minimum=0.01,
            maximum=0.9,
            value=DEFAULT_CONFIDENCE,
            step=0.01,
            label="Confidence Threshold",
            info="Lower values show more detections (less confident). Adjust to find optimal balance.",
        )

    # Row 4: actions.
    with gr.Row():
        detect_btn = gr.Button("Detect Traffic Signs", variant="primary")
        reset_btn = gr.Button("Clear")

    # Detection is wired to three triggers, in this order: an image
    # upload/change, a threshold change, and the explicit button. All three
    # share the same inputs and outputs.
    detect_triggers = (
        input_image.change,
        confidence_threshold.change,
        detect_btn.click,
    )
    for register in detect_triggers:
        register(
            fn=detect_traffic_signs,
            inputs=[input_image, confidence_threshold],
            outputs=[output_image, preprocessed_image],
            queue=True,
        )

    # Clear: blank all three images and restore the default threshold.
    reset_btn.click(
        fn=lambda: (None, None, None, DEFAULT_CONFIDENCE),
        outputs=[input_image, output_image, preprocessed_image, confidence_threshold],
    )

if __name__ == "__main__":
    # Enable Gradio's request queue, then serve on every network interface
    # (0.0.0.0) at port 7860.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)