Spaces:
Sleeping
Sleeping
import os
from tempfile import NamedTemporaryFile

import cv2
import numpy as np
import streamlit as st
from PIL import Image
from ultralytics import YOLO
# Path to the trained YOLOv8 weights; expected to sit alongside this script.
WEIGHTS_PATH = "best.pt"

# The model is created at module level, so it is rebuilt each time Streamlit
# re-executes this script.
# NOTE(review): st.cache_resource would avoid the reload, but it would share a
# single model object across all session threads. Ultralytics documents
# shared-model inference as not thread-safe, so confirm before caching.
model = YOLO(WEIGHTS_PATH)
# Object detection / segmentation on a single still image.
def detect_objects(image, classes, *, imgsz=320, conf=0.5):
    """Run the loaded YOLOv8 model on one image and return annotated images.

    Args:
        image: A PIL Image object representing the input image.
        classes: A list of class names. NOTE(review): currently unused; the
            model reports every class it was trained on. The parameter is
            kept so existing callers keep working.
        imgsz: Inference image size passed to ``model.predict`` (default 320,
            the value that was previously hard-coded).
        conf: Minimum confidence threshold passed to ``model.predict``
            (default 0.5, the value that was previously hard-coded).

    Returns:
        A list of RGB PIL Images with the detections/segmentations drawn on
        them, one per result returned by the model. If an error occurs, it is
        reported in the Streamlit UI and ``[image]`` (the unmodified input)
        is returned instead.
    """
    if image is None:
        st.error("An error occurred during object segmentation: no image provided")
        return [image]

    try:
        # Ultralytics accepts PIL images directly as a prediction source, so no
        # temporary JPEG on disk is needed. The old approach leaked one file per
        # call (i.e. per video frame) and failed for RGBA/palette PNGs, which
        # cannot be written as JPEG. Normalise to 3-channel RGB first.
        rgb_image = image.convert("RGB")
        results = model.predict(source=rgb_image, save=False, imgsz=imgsz, conf=conf)

        # predict() normally returns a list of Results; tolerate a single one.
        if not isinstance(results, list):
            results = [results]

        # Results.plot() returns a BGR numpy array (OpenCV channel order).
        # Convert to RGB PIL Images so st.image shows correct colours, and so
        # callers that treat the output as RGB (e.g. the video writer) are right.
        return [
            Image.fromarray(cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB))
            for result in results
        ]
    except Exception as e:
        # Best-effort UI boundary: surface the error and fall back to the
        # original image so the caller still has something to display.
        st.error(f"An error occurred during object segmentation: {e}")
        return [image]
# Function to perform object segmentation on video frames
def _remove_quietly(path):
    """Delete *path*, ignoring OS errors (best-effort temp-file cleanup)."""
    try:
        os.remove(path)
    except OSError:
        pass


def detect_objects_video(video_file, classes):
    """Annotate every frame of a video and write the result to a temp .mp4.

    Args:
        video_file: The uploaded video. Expected to be a file-like object
            exposing ``getvalue()`` or ``read()`` (e.g. a Streamlit upload).
            Its ``name`` is used only to choose the temp-file suffix.
        classes: A list of class names, forwarded to ``detect_objects``.

    Returns:
        A closed ``NamedTemporaryFile`` wrapper whose ``.name`` is the path of
        the annotated video. The file is not deleted, so the caller can
        stream it. Returns ``None`` if the input cannot be opened or the
        output writer cannot be created; an error is shown in the UI.
    """
    # cv2.VideoCapture needs a real path on disk. An uploaded file lives in
    # memory, and its ``name`` is only the client-side filename rather than a
    # server path, so copy the bytes to a temporary file first.
    suffix = os.path.splitext(getattr(video_file, "name", None) or "")[1] or ".mp4"
    if hasattr(video_file, "getvalue"):
        data = video_file.getvalue()
    else:
        data = video_file.read()
    with NamedTemporaryFile(delete=False, suffix=suffix) as input_tmp:
        input_tmp.write(data)
        input_path = input_tmp.name

    video = cv2.VideoCapture(input_path)
    out = None
    try:
        if not video.isOpened():
            st.error("Error: Unable to open video file.")
            return None

        frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates (e.g. 29.97) instead of truncating them, and
        # fall back to 30 fps when the container reports no rate.
        fps = video.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 30.0

        # The output must outlive this function because the caller streams it,
        # so it is not auto-deleted. Close our handle so only VideoWriter
        # writes to the path.
        temp_video_file = NamedTemporaryFile(delete=False, suffix=".mp4")
        temp_video_file.close()
        # NOTE(review): 'mp4v' output may not play in browsers via st.video;
        # an H.264 fourcc such as 'avc1' may be needed, depending on the
        # OpenCV build.
        out = cv2.VideoWriter(
            temp_video_file.name,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (frame_width, frame_height),
        )
        if not out.isOpened():
            st.error("Error: Unable to create output video file.")
            _remove_quietly(temp_video_file.name)
            return None

        while True:
            ok, frame = video.read()
            if not ok:
                break
            # OpenCV decodes frames as BGR; detect_objects expects a PIL image.
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            annotated = detect_objects(frame_pil, classes)
            # detect_objects returns a list of (RGB) images; use the first one.
            annotated_np = np.array(annotated[0])
            # VideoWriter silently drops frames whose size differs from the
            # size it was opened with, so enforce the expected size.
            if annotated_np.shape[:2] != (frame_height, frame_width):
                annotated_np = cv2.resize(annotated_np, (frame_width, frame_height))
            out.write(cv2.cvtColor(annotated_np, cv2.COLOR_RGB2BGR))

        return temp_video_file
    finally:
        video.release()
        if out is not None:
            out.release()
        _remove_quietly(input_path)
# Streamlit app
# Class names passed to the detection helpers (shared by the image and video paths).
ANIMAL_CLASSES = ['COW', 'Cattle', 'horse', 'pig', 'sheep', 'undefined']

st.title("YOLOv8 Object Segmentation")

# Upload image or video
uploaded_file = st.file_uploader("Upload Image or Video", type=["jpg", "jpeg", "png", "mp4"])

if uploaded_file is not None:
    # Compare case-insensitively so that e.g. "CLIP.MP4" is not sent down the
    # image path, where Image.open would fail.
    is_video = uploaded_file.name.lower().endswith(".mp4")
    if is_video:
        # Perform object segmentation on video
        st.write("Performing object segmentation on video...")
        try:
            detected_video = detect_objects_video(uploaded_file, classes=ANIMAL_CLASSES)
            # detect_objects_video returns None, after showing its own error,
            # when the video cannot be processed.
            if detected_video is not None:
                st.video(detected_video.name)
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        # Perform object segmentation on image. Opening the image is inside
        # the try so that a corrupt upload shows an error instead of crashing.
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; confirm the deployed version.
        try:
            image = Image.open(uploaded_file)
            st.image(image, caption='Uploaded Image', use_column_width=True)
            st.write("Performing object segmentation...")
            detected_segmentation = detect_objects(image, classes=ANIMAL_CLASSES)
            for annotated_image in detected_segmentation:
                st.image(annotated_image, caption='Segmentation Mask', use_column_width=True)
        except Exception as e:
            st.error(f"An error occurred: {e}")