# Author: barath19
# Last commit: 349b8ae "update font" (verified)
import os
from functools import lru_cache

import cv2
import gradio as gr
import requests
from ultralytics import YOLO
# Remote example images; each is downloaded once at startup and saved locally
# as image_<index>.jpg.
file_urls = [
    "https://www.kp-glas.de/wp-content/uploads/2023/02/Bierflaschen.jpeg",
    "https://www.bevindustry.com/ext/resources/issues/2020/February/Heineken-0-0-Alcohol-Free-Beer.jpg?1581450636",
    "https://spice-world.co.za/cdn/shop/products/3520_648ad722263c64.28810246_drinks_20copy_1db265d5-e818-45d3-973b-3e046a6b0931_800x.jpg?v=1697124785",
    "https://i.redd.it/0e2zml4mwz021.jpg",
]
def download_file(url, save_name):
    """Download *url* to the local path *save_name* unless it already exists.

    Args:
        url: HTTP(S) address of the file to fetch.
        save_name: Local path the response body is written to.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. This
            keeps an error page from being saved under an image filename,
            where the exists-check would then keep it forever.
        requests.RequestException: on connection failure or timeout.
    """
    if os.path.exists(save_name):
        return
    # Without a timeout, requests can block indefinitely on an unresponsive host.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # The context manager guarantees the file handle is closed.
    with open(save_name, "wb") as f:
        f.write(response.content)
# Fetch every example image (skipped if already on disk), naming them
# image_0.jpg, image_1.jpg, ... in list order.
for i, url in enumerate(file_urls):
    download_file(url, f"image_{i}.jpg")
# Choices for the model dropdown as (display label, weights file) pairs; the
# weights file is the value handed to the inference callback.
model_paths = [
    ("Yolov8n (Accurate)", "best.pt"),
    ("Yolov8n half-precision (Faster)", "best.onnx"),
]
# Example rows for the Gradio interface: one single-item row per downloaded
# image. Derived from file_urls, using the same image_<index>.jpg naming as the
# download loop, so adding or removing a URL keeps the examples in sync.
path = [[f"image_{i}.jpg"] for i in range(len(file_urls))]
@lru_cache(maxsize=4)
def _load_model(weights_path):
    """Load a YOLO model once per weights file and reuse it on later requests."""
    # Gradio calls the inference function on every request; re-constructing
    # YOLO(...) each time would re-read the weights from disk.
    return YOLO(weights_path)


def _draw_detection(image, box, label, font_scale):
    """Draw one red bounding box and its white-backed black label onto *image* in place.

    Args:
        image: BGR image array (as returned by ``cv2.imread``), modified in place.
        box: ``(x1, y1, x2, y2)`` pixel coordinates of the detection.
        label: Text drawn above the box's top-left corner.
        font_scale: Scale passed to ``cv2.getTextSize`` / ``cv2.putText``.
    """
    x1, y1, x2, y2 = (int(v) for v in box)
    cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2, lineType=cv2.LINE_AA)

    (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2)
    # Push the label down when the box touches the top edge so it stays visible.
    label_y1 = max(y1, text_h + 10) - text_h
    label_y2 = label_y1 + text_h + 5
    label_x1 = x1
    label_x2 = label_x1 + text_w
    cv2.rectangle(image, (label_x1, label_y1), (label_x2, label_y2), (255, 255, 255), thickness=-1)
    cv2.putText(image, label, (label_x1, label_y2 - 5), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), 2, lineType=cv2.LINE_AA)


def show_preds_image(image_path, selected_model):
    """Run the selected YOLO model on an image and return it annotated.

    Args:
        image_path: File path of the uploaded image (Gradio ``type="filepath"``),
            or ``None`` when nothing was uploaded.
        selected_model: Weights file to load with ``YOLO`` (a dropdown value).

    Returns:
        The image as an RGB NumPy array. Each detection above the 0.8
        confidence threshold gets a red box and a class-name label.

    Raises:
        gr.Error: if no image was provided or the file cannot be decoded.
    """
    conf_threshold = 0.8
    font_scale = 2

    if image_path is None:
        raise gr.Error("Please upload an image first.")
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread signals an unreadable file by returning None
        raise gr.Error(f"Could not read image: {image_path}")

    model = _load_model(selected_model)
    # Predict on the same BGR array we draw on. This decodes the file once and
    # guarantees the box coordinates refer to exactly this image.
    outputs = model.predict(source=image, conf=conf_threshold)
    results = outputs[0].cpu().numpy()

    for box, class_id in zip(results.boxes.xyxy, results.boxes.cls):
        label = str(model.names[int(class_id)])
        _draw_detection(image, box, label, font_scale)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Inputs: the uploaded image (handed to the callback as a file path) and the
# model selector, whose value is a weights file from model_paths.
inputs_image = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        choices=model_paths,
        value='best.pt',
        label="Select Model",
        interactive=True,
    ),
]

# Output: the annotated image, returned by the callback as a NumPy array.
outputs_image = [gr.Image(type="numpy", label="Output Image")]
# Single-image detection page: image + model choice in, annotated image out,
# with the downloaded images offered as (uncached) examples.
interface_image = gr.Interface(
    show_preds_image,
    inputs_image,
    outputs_image,
    examples=path,
    cache_examples=False,
    title="Beverage Container Detector",
)
# Wrap the interface in a tabbed layout. The queued server is started only
# when this file is executed as a script, so importing the module does not
# launch a server.
demo = gr.TabbedInterface(
    [interface_image],
    tab_names=['Image inference'],
)

if __name__ == "__main__":
    demo.queue().launch()