# farm-count / app.py
# Provenance: Hugging Face Space by nishu692003 — commit 155ac0a ("Create app.py")
import os
from tempfile import NamedTemporaryFile

import cv2
import numpy as np
import streamlit as st
from PIL import Image
from ultralytics import YOLO
# Load the YOLOv8 weights once per session. st.cache_resource keeps the model
# alive across Streamlit script reruns instead of re-reading best.pt each time.
@st.cache_resource
def _load_model():
    """Return the YOLOv8 model loaded from 'best.pt' (expected next to app.py)."""
    return YOLO("best.pt")

model = _load_model()
def detect_objects(image, classes):
    """Run YOLOv8 detection on one image and return annotated copies.

    Args:
        image: PIL.Image.Image to run detection on (may be None, in which
            case an empty temp file is fed to the model).
        classes: List of class-name strings. NOTE(review): currently unused —
            the model predicts all of its trained classes. Kept so existing
            callers keep working; wiring it to predict(classes=...) would
            need name->index mapping via model.names (confirm before use).

    Returns:
        A list of annotated frames as RGB numpy arrays, one per prediction
        result. On any error the original ``image`` is returned inside a
        one-element list so callers can always iterate the result.
    """
    temp_path = None
    try:
        # model.predict() takes a source path, so persist the PIL image first.
        with NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
            temp_path = temp_file.name
            if image is not None:
                image.save(temp_path, format="JPEG")
        results = model.predict(source=temp_path, save=False, imgsz=320, conf=0.5)
        # predict() may return a single Results object or a list of them;
        # normalize to a list so one loop handles both shapes.
        if not isinstance(results, list):
            results = [results]
        annotated_images = []
        for result in results:
            # Results.plot() returns a BGR array (OpenCV convention).
            # Convert to RGB so st.image / PIL render the colors correctly.
            annotated = result.plot()
            annotated_images.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
        return annotated_images
    except Exception as e:
        st.error(f"An error occurred during object segmentation: {e}")
        return [image]  # Fall back to the original image on any failure.
    finally:
        # delete=False means nothing else will remove the temp file; do it
        # here so per-frame video calls don't leak one file per frame.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
def detect_objects_video(video_file, classes):
    """Run YOLOv8 detection on every frame of an uploaded video.

    Args:
        video_file: File-like object with a ``read()`` method (e.g. a
            Streamlit UploadedFile) containing the raw video bytes.
        classes: List of class names, forwarded to detect_objects()
            (currently unused there; kept for interface compatibility).

    Returns:
        The NamedTemporaryFile holding the annotated .mp4, or None when the
        input cannot be opened. Created with delete=False, so the caller
        owns the output file's lifetime.
    """
    # Streamlit's UploadedFile.name is only the original filename, not a
    # filesystem path, so spool the bytes to a real temp file for OpenCV.
    with NamedTemporaryFile(delete=False, suffix=".mp4") as src_file:
        src_file.write(video_file.read())
        src_path = src_file.name

    video = cv2.VideoCapture(src_path)
    if not video.isOpened():
        st.error("Error: Unable to open video file.")
        os.unlink(src_path)
        return None

    # Output container for the annotated video.
    temp_video_file = NamedTemporaryFile(delete=False, suffix=".mp4")
    frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 fps; a 0-fps VideoWriter is unusable, so
    # fall back to a sane default.
    fps = int(video.get(cv2.CAP_PROP_FPS)) or 30
    out = cv2.VideoWriter(
        temp_video_file.name,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (frame_width, frame_height),
    )
    try:
        while True:
            ret, frame = video.read()
            if not ret:
                break  # End of stream.
            # OpenCV decodes BGR; detect_objects expects a PIL (RGB) image.
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            annotated_frames = detect_objects(frame_pil, classes)
            annotated_np = np.array(annotated_frames[0])
            # detect_objects is expected to return RGB frames; VideoWriter
            # wants BGR, so convert back before writing.
            out.write(cv2.cvtColor(annotated_np, cv2.COLOR_RGB2BGR))
    finally:
        # Always release handles and remove the spooled source bytes, even
        # if a frame raises mid-loop.
        video.release()
        out.release()
        os.unlink(src_path)
    return temp_video_file
# --- Streamlit UI ----------------------------------------------------------
st.title("YOLOv8 Object Segmentation")

# Class names for the farm-animal model, defined once so the image and video
# code paths cannot drift out of sync.
CLASS_NAMES = ['COW', 'Cattle', 'horse', 'pig', 'sheep', 'undefined']

uploaded_file = st.file_uploader("Upload Image or Video", type=["jpg", "jpeg", "png", "mp4"])
if uploaded_file is not None:
    if uploaded_file.name.endswith(".mp4"):
        # Video path: annotate every frame, then play back the result.
        st.write("Performing object segmentation on video...")
        try:
            detected_video = detect_objects_video(uploaded_file, classes=CLASS_NAMES)
            st.video(detected_video.name)
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        # Image path: show the upload, then each annotated detection result.
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)
        st.write("Performing object segmentation...")
        try:
            detected_segmentation = detect_objects(image, classes=CLASS_NAMES)
            for annotated_image in detected_segmentation:
                st.image(annotated_image, caption='Segmentation Mask', use_column_width=True)
        except Exception as e:
            st.error(f"An error occurred: {e}")