Spaces:
Sleeping
Sleeping
| # Standard library imports | |
| # (Add any necessary imports for future object detection implementation) | |
| # Third-party imports | |
| from ultralytics import YOLO | |
| # Local imports | |
| from utils.image_utils import load_image, preprocess_image | |
# Weights file for the Ultralytics YOLO11 nano detection model.
YOLO_MODEL = "yolo11n.pt"
# Shared model instance. It is created lazily by load_model() on first use and
# then reused, so the weights are not reloaded on every call. It stays None
# until loading succeeds.
model = None
def load_model():
    """
    Load the YOLO11 nano detection model (``YOLO_MODEL``) into the module-level ``model``.

    This is safe to call repeatedly. object_detection() calls it on every
    request, and once a load has succeeded the cached instance is reused
    instead of the weights being read again.

    Returns:
        The loaded YOLO model, or None if loading failed. On failure the error
        is printed and ``model`` stays None, so callers can check it.
    """
    global model
    if model is not None:
        print("YOLO model already loaded.")
        return model
    try:
        model = YOLO(YOLO_MODEL)
    except Exception as e:  # best-effort: report, leave `model` as None for callers
        print(f"Error loading YOLO model: {e}")
        return None
    print("YOLO model loaded successfully.")
    return model
def _select_image(input_type, uploaded_image, image_url, base64_string):
    """Return the image selected by ``input_type``, or None if nothing usable was given.

    An uploaded PIL image is returned as-is. A URL or Base64 string is passed to
    ``load_image``, which may itself return None on failure.
    """
    if input_type == "Upload File" and uploaded_image is not None:
        print("Using uploaded image (PIL) for object detection")  # Debug print
        return uploaded_image
    if input_type == "Enter URL" and image_url and image_url.strip():
        print(f"Using URL for object detection: {image_url}")  # Debug print
        return load_image(image_url)
    if input_type == "Enter Base64" and base64_string and base64_string.strip():
        print("Using Base64 string for object detection")  # Debug print
        return load_image(base64_string)
    print("No valid input provided for object detection based on selected type.")
    return None


def _extract_detections(result, names):
    """Convert one Ultralytics result's boxes into a list of plain detection dicts."""
    detections = []
    boxes = result.boxes
    if not boxes:
        return detections
    names = names or {}
    for box in boxes:
        # box.xywh[0] is [x_center, y_center, width, height]; conf/cls hold one value each.
        x_center, y_center, width, height = (
            round(float(coord)) for coord in box.xywh[0].tolist()
        )
        class_id = int(box.cls[0])
        try:
            class_name = names[class_id]
        except (KeyError, IndexError):
            # Unknown id: fall back to the numeric id instead of failing the request.
            class_name = str(class_id)
        detections.append(
            {
                "box": {"x": x_center, "y": y_center, "w": width, "h": height},
                "confidence": round(float(box.conf[0]), 4),
                "class_id": class_id,
                "class_name": class_name,
            }
        )
    return detections


def object_detection(input_type, uploaded_image, image_url, base64_string):
    """
    Perform object detection on an image from one of several input types using YOLO11 nano.

    Args:
        input_type (str): The selected input method ("Upload File", "Enter URL", "Enter Base64").
        uploaded_image (PIL.Image.Image): The uploaded image (used if input_type is "Upload File").
        image_url (str): The image URL (used if input_type is "Enter URL").
        base64_string (str): The image Base64 string (used if input_type is "Enter Base64").

    Returns:
        tuple: A tuple containing:
            - numpy.ndarray: The RGB image with detected objects drawn on it, or None if
              the model is not loaded, no valid input was given, or an error occurred.
            - list[dict]: One dict per detection, with keys "box" ({"x", "y", "w", "h"}:
              box center and size, rounded to ints), "confidence" (rounded to 4 decimals),
              "class_id" and "class_name". The list is empty when nothing is detected,
              and None on failure.
    """
    load_model()  # Loads on the first call; later calls reuse the cached model.
    if model is None:
        print("YOLO model is not loaded. Cannot perform object detection.")
        return None, None  # Return None for both outputs

    image = _select_image(input_type, uploaded_image, image_url, base64_string)
    if image is None:
        return None, None  # No usable input, or load_image failed

    try:
        # Preprocess the image (convert PIL to numpy, ensure RGB).
        processed_image_np = preprocess_image(image)
        # Ultralytics interprets numpy inputs as HWC BGR (the cv2 convention), so an
        # RGB array must be channel-flipped before inference. .copy() makes the
        # reversed view contiguous. Non-3-channel arrays are passed through untouched.
        is_rgb = processed_image_np.ndim == 3 and processed_image_np.shape[-1] == 3
        model_input = processed_image_np[..., ::-1].copy() if is_rgb else processed_image_np

        results = model.predict(model_input)
        if not results:
            print("Object detection performed successfully.")
            return processed_image_np, []

        result = results[0]
        raw_data = _extract_detections(result, model.names)
        # plot() draws on the image it was given (BGR here) and returns it in that
        # order, so flip back to RGB for display.
        plotted = result.plot()
        result_image_np = plotted[..., ::-1].copy() if is_rgb else plotted
        print("Object detection performed successfully.")
        return result_image_np, raw_data  # Return both the image and raw data
    except Exception as e:  # UI boundary: report and return empty outputs
        print(f"Error during YOLO object detection: {e}")
        return None, None  # Return None for both outputs