# Commit cfe01e5 (author: VietCat): Add input validation to prevent None image errors
import gradio as gr
import cv2
import numpy as np
from model import TrafficSignDetector
# Single detector instance, created once at import time and reused by every
# call to detect_traffic_signs below.
# NOTE(review): 'config.yaml' is passed as given; presumably it is resolved
# relative to the current working directory — confirm in model.TrafficSignDetector.
_DETECTOR_CONFIG = 'config.yaml'
detector = TrafficSignDetector(_DETECTOR_CONFIG)
def _to_rgb_array(image):
    """Normalize an uploaded image to a non-empty ``(H, W, 3)`` RGB numpy array.

    PIL images (anything with a ``convert`` method) are converted with
    ``convert("RGB")``, so grayscale, palette and alpha modes are accepted.
    Numpy inputs may be 2-D grayscale, ``(H, W, 1)``, ``(H, W, 3)`` or
    ``(H, W, 4)``; for 4 channels the last one is dropped (assumed to be
    alpha in RGBA order — TODO confirm for non-Gradio callers).

    Returns:
        The RGB array, or None when *image* cannot be interpreted as a
        non-empty 3-channel image (a diagnostic is printed in that case).
    """
    if hasattr(image, 'convert'):
        # PIL path: convert("RGB") folds "L", "P", "LA", "RGBA", ... modes into
        # three channels instead of producing 2-D/2-/4-channel arrays.
        try:
            image = np.array(image.convert("RGB"))
        except (ValueError, OSError) as err:
            # Some exotic modes have no direct RGB conversion in Pillow.
            print(f"Could not convert image to RGB: {err}")
            return None
        print(f"Converted PIL to numpy array, shape: {image.shape}")

    image = np.asarray(image)
    if image.ndim == 2:
        # Grayscale: replicate the single channel so downstream code sees RGB.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.ndim == 3 and image.shape[2] == 4:
        # Drop alpha; make contiguous so OpenCV gets a plain buffer.
        image = np.ascontiguousarray(image[:, :, :3])

    if image.size == 0 or image.ndim != 3 or image.shape[2] != 3:
        print(f"Invalid image: shape={image.shape}")
        return None
    return image


def detect_traffic_signs(image, confidence_threshold):
    """
    Process the uploaded image and return the image with detected signs.

    Args:
        image: The uploaded image, as a PIL image (the UI uses ``type="pil"``)
            or a numpy array. May be None when the input is cleared.
        confidence_threshold: Minimum detection confidence, forwarded to
            ``detector.detect``.

    Returns:
        A ``(result_image, preprocessed_image)`` tuple of RGB arrays for
        display. It is ``(None, None)`` when no image is given or the image
        cannot be normalized to 3-channel RGB.
    """
    # Validate input image
    if image is None:
        print("No image provided")
        return None, None
    print(f"Received image type: {type(image)}")

    rgb_image = _to_rgb_array(image)
    if rgb_image is None:
        return None, None

    # Convert RGB to BGR for OpenCV
    bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    print(f"Converted to BGR, shape: {bgr_image.shape}")

    # Perform detection; both returned images are treated as BGR below.
    result_image, preprocessed_image = detector.detect(bgr_image, confidence_threshold=confidence_threshold)

    # Convert back to RGB for Gradio
    result_image = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    preprocessed_image = cv2.cvtColor(preprocessed_image, cv2.COLOR_BGR2RGB)
    return result_image, preprocessed_image
# Threshold used both as the slider's initial value and as the value the
# Clear button restores.
_DEFAULT_CONFIDENCE = 0.30

# Gradio UI: upload + result images, the letterboxed preprocessed view,
# a confidence slider, and Detect / Clear buttons.
with gr.Blocks(title="Traffic Sign Detector") as demo:
    gr.Markdown("# Traffic Sign Detector")
    gr.Markdown("Upload an image to detect traffic signs using YOLOv8. Detection runs automatically when you upload or adjust the threshold.")

    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Detected Signs", interactive=False)

    with gr.Row():
        preprocessed_image = gr.Image(label="Preprocessed Image (640x640, Letterboxed)", interactive=False)

    with gr.Row():
        confidence_threshold = gr.Slider(
            minimum=0.01,
            maximum=0.9,
            value=_DEFAULT_CONFIDENCE,
            step=0.01,
            label="Confidence Threshold",
            info="Lower values show more detections (less confident). Adjust to find optimal balance.",
        )

    with gr.Row():
        detect_btn = gr.Button("Detect Traffic Signs", variant="primary")
        reset_btn = gr.Button("Clear")

    # One detection handler, registered (in this order) on: a new upload,
    # a threshold change, and a click on the Detect button.
    for _register_listener in (input_image.change, confidence_threshold.change, detect_btn.click):
        _register_listener(
            fn=detect_traffic_signs,
            inputs=[input_image, confidence_threshold],
            outputs=[output_image, preprocessed_image],
            queue=True,
        )

    # Clear empties all three images and restores the default threshold.
    reset_btn.click(
        fn=lambda: (None, None, None, _DEFAULT_CONFIDENCE),
        outputs=[input_image, output_image, preprocessed_image, confidence_threshold],
    )
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    # Presumably 0.0.0.0 is so the app is reachable from outside a container.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)