valentynliubchenko
first commit
211cb30
raw
history blame
1.4 kB
import os
import gradio as gr
from ultralytics import YOLO
import numpy as np
# Weights for the YOLOv8 medium model trained on xView (see the UI title below).
MODEL_PATH = "./model/xViewyolov8m_v8_100e.pt"
# Directory whose files are offered as clickable examples in the UI.
EXAMPLES_DIR = "examples"

# Loaded once at import time and shared by every request.
model = YOLO(MODEL_PATH)

# os.listdir returns entries in arbitrary order. Sort them so the examples
# gallery is stable across restarts, and skip hidden files (e.g. .DS_Store)
# that are not meant to be examples. Each example is a one-element list
# because the interface has a single input component.
example_list = [
    [os.path.join(EXAMPLES_DIR, name)]
    for name in sorted(os.listdir(EXAMPLES_DIR))
    if not name.startswith(".")
]
def process_image(input_image, conf=0.6):
    """Run the xView detector on one image and summarise what it found.

    Args:
        input_image: Value of the ``gr.Image()`` input, or ``None`` when the
            image is cleared (``live=True`` re-runs on every change).
            Presumably an HxWx3 RGB ``uint8`` array (Gradio's default
            ``type="numpy"``) -- TODO confirm against the Gradio version used.
        conf: Minimum confidence for a detection to be kept (default 0.6,
            matching the value advertised in the UI title).

    Returns:
        ``(annotated_image, counts_text)``. ``annotated_image`` is the input
        with predicted boxes drawn, or ``None`` if there was no input or no
        result. ``counts_text`` is ``"Class Counts:\\n"`` followed by one
        ``"\\n<class>: <count>"`` entry per detected class name.
    """
    header = "Class Counts:\n"
    if input_image is None:
        # Nothing to detect on (e.g. the image was cleared); skip the model.
        return None, header

    # Gradio hands over RGB arrays, while ultralytics documents NumPy sources
    # as BGR (OpenCV order). Swap channels so the model sees true colours;
    # non-3-channel / non-array inputs are passed through unchanged.
    swap = getattr(input_image, "ndim", 0) == 3 and input_image.shape[2] == 3
    source = np.ascontiguousarray(input_image[..., ::-1]) if swap else input_image

    results = model.predict(source, conf=conf)

    annotated = None  # stays None if predict yields no results
    class_counts = {}
    for r in results:
        plotted = r.plot()  # drawn on the frame we passed in (BGR if swapped)
        if swap:
            plotted = plotted[..., ::-1]  # back to RGB for display in Gradio
        annotated = np.ascontiguousarray(plotted).astype(np.uint8)
        for cls_id in r.boxes.cls.tolist():
            # Class ids come back as floats; r.names is keyed by int id.
            name = r.names[int(cls_id)]
            class_counts[name] = class_counts.get(name, 0) + 1

    counts_text = header + "".join(
        f"\n{name}: {count}" for name, count in class_counts.items()
    )
    return annotated, counts_text
# Gradio UI: one image in; the annotated image plus per-class counts out.
# live=True re-runs process_image whenever the input image changes.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(),
    outputs=["image", gr.Textbox(label="More info")],
    title="YOLO Object detection. Trained on xView dataset. Medium model. Predict with conf=0.6",
    description='''The xView dataset is composed of satellite images collected from WorldView-3 satellites at a 0.3m ground sample distance.\n
It contains over 1 million objects across 60 classes in over 1,400 km² of imagery.''',
    live=True,
    examples=example_list,
)
iface.launch()