# Tomato ripeness detection demo (Gradio + Ultralytics YOLO).
# Provenance note from original upload: author Marolahy, commit ab90cf1 ("Fix typo").
import gradio as gr
import cv2
from ultralytics import YOLO
# Function to process and display predictions on images
def process_image(image_path):
    """Run tomato ripeness detection on a single image and draw the results.

    Parameters
    ----------
    image_path : str
        Filesystem path to the input image (supplied by the Gradio widget).

    Returns
    -------
    numpy.ndarray
        The annotated image converted to RGB (Gradio expects RGB; OpenCV
        reads and draws in BGR).

    Raises
    ------
    ValueError
        If the image cannot be read from ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread silently returns None on failure; fail loudly instead
        # of crashing later with an opaque drawing error.
        raise ValueError(f"Could not read image: {image_path}")
    # Load the YOLO model (weights expected next to this script).
    # NOTE(review): reloading per call is slow but keeps the function
    # self-contained; consider caching at module level if latency matters.
    yolo_model = YOLO("best.pt")
    predictions = yolo_model.predict(source=image_path)
    # Single image -> single result; pull it to CPU as plain numpy arrays.
    results = predictions[0].cpu().numpy()
    # Extracting bounding boxes and class names
    boxes = results.boxes
    class_names = yolo_model.model.names
    for box, confidence, class_id in zip(boxes.xyxy, boxes.conf, boxes.cls):
        x1, y1, x2, y2 = map(int, box)
        label = class_names[int(class_id)]
        is_ripe = label.lower() == "ripe_tomato"  # computed once per box
        # Red box for ripe tomatoes, green for unripe (BGR order).
        color = (0, 0, 255) if is_ripe else (0, 255, 0)
        # Draw bounding box and label
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
        text = f"ripe {confidence:.2f}" if is_ripe else f"unripe {confidence:.2f}"
        (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
        # Filled background rectangle so the white label text stays readable.
        cv2.rectangle(img, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1)
        cv2.putText(img, text, (x1, y1 - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Function to process and display predictions on videos
def process_video(video_path):
    """Stream annotated frames from a video through the YOLO model.

    Parameters
    ----------
    video_path : str
        Filesystem path to the input video (supplied by the Gradio widget).

    Yields
    ------
    numpy.ndarray
        Each processed frame, annotated and converted to RGB, so Gradio can
        display results progressively while the video is being decoded.
    """
    video_capture = cv2.VideoCapture(video_path)
    # Load the YOLO model once, outside the frame loop.
    yolo_model = YOLO("best.pt")
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                # End of stream (or decode error) -> stop iterating.
                break
            frame_copy = frame.copy()
            predictions = yolo_model.predict(source=frame)
            results = predictions[0].cpu().numpy()
            boxes = results.boxes
            class_names = yolo_model.model.names
            for box, confidence, class_id in zip(boxes.xyxy, boxes.conf, boxes.cls):
                x1, y1, x2, y2 = map(int, box)
                label = class_names[int(class_id)]
                is_ripe = label.lower() == "ripe_tomato"  # computed once per box
                # Red box for ripe tomatoes, green for unripe (BGR order).
                color = (0, 0, 255) if is_ripe else (0, 255, 0)
                # Draw bounding box and label
                cv2.rectangle(frame_copy, (x1, y1), (x2, y2), color, 1)
                text = f"ripe {confidence:.2f}" if is_ripe else f"unripe {confidence:.2f}"
                (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
                # Filled background rectangle so the white label text stays readable.
                cv2.rectangle(frame_copy, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1)
                cv2.putText(frame_copy, text, (x1, y1 - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)
            yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
    finally:
        # Release the capture even if the consumer abandons the generator
        # or prediction raises mid-stream.
        video_capture.release()
# Gradio interface for image input: takes a file path, returns the
# annotated image as a numpy RGB array.
image_interface = gr.Interface(
    fn=process_image,
    inputs=[gr.components.Image(type="filepath", label="Select Image")],
    outputs=[gr.components.Image(type="numpy", label="Processed Image")],
    title="Tomato Ripeness Detection",
    # Bundled sample images shown as clickable examples in the UI.
    examples=["samples/image_0.jpg", "samples/image_1.jpg", "samples/image_2.jpg"],
    cache_examples=False,
)
# Gradio interface for video input: process_video is a generator, so the
# output image updates frame-by-frame as the video is processed.
video_interface = gr.Interface(
    fn=process_video,
    inputs=[gr.components.Video(label="Select Video")],
    outputs=[gr.components.Image(type="numpy", label="Processed Frame")],
    title="Tomato Ripeness Detection in Video",
    examples=["samples/video.mp4"],
    cache_examples=False,
)
if __name__ == "__main__":
    # Create a tabbed interface for both image and video processing.
    # .queue() is required for the streaming (generator) video output.
    gr.TabbedInterface(
        [image_interface, video_interface],
        tab_names=["Image Detection", "Video Detection"],
    ).queue().launch()