"""Gradio demo: YOLOv8 object detection with models trained on the xView dataset."""

import os
from collections import Counter
from typing import Optional, Tuple

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO

EXAMPLES_DIR = "examples"
MODEL_DIR = "./saved_model"
# Single source of truth for the default threshold: used both by the slider
# and as the fallback when the callback receives no value.
DEFAULT_CONF = 0.6
NO_DETECTIONS_MESSAGE = "No objects detected."

# Weight files and their display names; the two lists are index-aligned.
model_options = [
    "xViewyolov8n_v8_100e.pt",
    "xViewyolov8s_v8_100e.pt",
    "xViewyolov8m_v8_100e.pt",
]
model_names = ["Nano", "Small", "Medium"]

print("before loading examples")
# Sorted because os.listdir() returns entries in arbitrary order and the
# example gallery should look the same on every start.
example_list = [
    [os.path.join(".", EXAMPLES_DIR, example)]
    for example in sorted(os.listdir(EXAMPLES_DIR))
]

print("before loading models")
# All models are loaded once at start-up so switching between them in the UI
# does not pay the loading cost per request.
models = [YOLO(os.path.join(MODEL_DIR, option)) for option in model_options]
print("finish preparation")


def _format_class_counts(class_counts):
    """Render a mapping of class name -> count as the "More info" text."""
    return "Class Counts:\n" + "".join(
        f"\n{name}: {count}" for name, count in class_counts.items()
    )


def process_image(
    input_image: Optional[np.ndarray],
    model_name: Optional[str],
    conf: Optional[float],
) -> Tuple[Optional[np.ndarray], str]:
    """Run the selected YOLO model on an image and summarise the detections.

    Args:
        input_image: Image array as delivered by the ``gr.Image`` input, or
            ``None`` when the input has been cleared. It is converted from
            RGB to BGR before inference.
        model_name: One of ``model_names``; ``None`` selects the first entry.
        conf: Confidence threshold passed to ``predict``; ``None`` selects
            ``DEFAULT_CONF``.

    Returns:
        A ``(annotated_image, summary)`` tuple. ``annotated_image`` is the
        plotted detections converted back to RGB, or ``None`` when there is
        no input or nothing was detected. ``summary`` lists the number of
        detections per class, in order of first appearance.

    Raises:
        ValueError: If ``model_name`` is not in ``model_names``.
    """
    print('start processing: ')
    if input_image is None:
        return None, NO_DETECTIONS_MESSAGE

    input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    print(f"model_name : {model_name}")
    print(f"conf : {conf}")
    if model_name is None:
        model_name = model_names[0]
    if conf is None:
        conf = DEFAULT_CONF

    model_index = model_names.index(model_name)
    print('model_index: ')
    print(model_index)
    model = models[model_index]

    results = model.predict(input_image, conf=conf)
    # Guard against an empty result list so the return below can never hit
    # an unassigned variable.
    if not results:
        return None, NO_DETECTIONS_MESSAGE

    # A single input image yields a single result object.
    result = results[0]
    boxes = result.boxes
    # Check before plotting: there is no point rendering an image that is
    # going to be discarded.
    if boxes is None or len(boxes) == 0:
        return None, NO_DETECTIONS_MESSAGE

    # Class ids come back as floats; cast to int so the lookup in `names`
    # does not rely on 0.0 and 0 hashing alike.
    class_counts = Counter(
        result.names[int(class_id)] for class_id in boxes.cls.tolist()
    )

    annotated = result.plot().astype(np.uint8)
    annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
    return annotated, _format_class_counts(class_counts)


iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(),
        gr.Radio(model_names, label="Choose model", value=model_names[0]),
        gr.Slider(
            minimum=0.2,
            maximum=1.0,
            step=0.1,
            label="Confidence Threshold",
            value=DEFAULT_CONF,
        ),
    ],
    outputs=["image", gr.Textbox(label="More info")],
    title="YOLO Object detection. \nTrained on xView dataset.",
    description='''The xView dataset is composed of satellite images collected from WorldView-3 satellites at a 0.3m ground sample distance.\n It contains over 1 million objects across 60 classes in over 1,400 km of imagery. https://challenge.xviewdataset.org ''',
    live=True,
    examples=example_list,
)

# Only start the web server when executed as a script, so importing this
# module (e.g. to reuse process_image) does not block on launch().
if __name__ == "__main__":
    iface.launch()