Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from model import TrafficSignDetector | |
# Build the module-level detector once, configured from config.yaml,
# so every request handled below reuses the same instance.
detector = TrafficSignDetector("config.yaml")
def detect_traffic_signs(image, confidence_threshold):
    """
    Run traffic-sign detection on an uploaded image.

    Args:
        image: The uploaded picture. Either an object exposing ``convert``
            (treated as a PIL image), an RGB(A) numpy array with three
            dimensions, or ``None`` when nothing has been uploaded.
        confidence_threshold: Passed straight through to
            ``detector.detect`` as its ``confidence_threshold`` argument.

    Returns:
        A ``(result_image, preprocessed_image)`` tuple holding the two
        images returned by ``detector.detect`` after BGR -> RGB conversion,
        or ``(None, None)`` when the input is missing or unusable.
    """
    # Nothing uploaded yet (also what the Clear button produces).
    if image is None:
        print("No image provided")
        return None, None

    print(f"Received image type: {type(image)}")
    if hasattr(image, 'convert'):
        # Normalise the colour mode before converting: grayscale/palette
        # images would otherwise become 2-D arrays and be rejected below,
        # and CMYK data would be misread as RGBA.
        image = np.array(image.convert("RGB"))
        print(f"Converted PIL to numpy array, shape: {image.shape}")

    # The RGB -> BGR conversion below needs height x width x channels data
    # with 3 or 4 channels; anything else is reported as invalid instead of
    # raising inside OpenCV.
    if image.size == 0 or image.ndim != 3 or image.shape[2] not in (3, 4):
        print(f"Invalid image: shape={image.shape}")
        return None, None

    # The detector is fed BGR data (OpenCV channel order).
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    print(f"Converted to BGR, shape: {image.shape}")

    # Perform detection
    result_image, preprocessed_image = detector.detect(
        image, confidence_threshold=confidence_threshold
    )

    # Gradio displays RGB, so swap the channel order back on both outputs.
    result_image = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    preprocessed_image = cv2.cvtColor(preprocessed_image, cv2.COLOR_BGR2RGB)
    return result_image, preprocessed_image
# Assemble the Gradio UI: widgets first, then the event wiring.
with gr.Blocks(title="Traffic Sign Detector") as demo:
    gr.Markdown("# Traffic Sign Detector")
    gr.Markdown("Upload an image to detect traffic signs using YOLOv8. Detection runs automatically when you upload or adjust the threshold.")

    # Upload widget next to the annotated result.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_image = gr.Image(interactive=False, label="Detected Signs")

    # Second output: the preprocessed image returned by the detector.
    with gr.Row():
        preprocessed_image = gr.Image(
            interactive=False,
            label="Preprocessed Image (640x640, Letterboxed)",
        )

    with gr.Row():
        confidence_threshold = gr.Slider(
            label="Confidence Threshold",
            info="Lower values show more detections (less confident). Adjust to find optimal balance.",
            value=0.30,
            minimum=0.01,
            maximum=0.9,
            step=0.01,
        )

    with gr.Row():
        detect_btn = gr.Button("Detect Traffic Signs", variant="primary")
        reset_btn = gr.Button("Clear")

    # Image upload, threshold change and the manual button all run the same
    # detection call with the same inputs and outputs, registered in that
    # order.
    for _register in (
        input_image.change,
        confidence_threshold.change,
        detect_btn.click,
    ):
        _register(
            fn=detect_traffic_signs,
            inputs=[input_image, confidence_threshold],
            outputs=[output_image, preprocessed_image],
            queue=True,
        )

    # Clear: empty the three image widgets and put the slider back to 0.30.
    reset_btn.click(
        fn=lambda: (None,) * 3 + (0.30,),
        outputs=[input_image, output_image, preprocessed_image, confidence_threshold],
    )
if __name__ == "__main__":
    # Turn on the request queue, then serve on all interfaces at port 7860.
    queued_demo = demo.queue()
    queued_demo.launch(server_port=7860, server_name="0.0.0.0")