File size: 2,431 Bytes
211cb30
1929ba0
211cb30
 
 
 
44068da
e2717de
0fe0060
f122098
0fe0060
7a9e429
 
211cb30
0c6b368
7a9e429
2b9ca25
 
 
6729627
2b9ca25
 
2ee677b
 
0fe0060
2ee677b
 
0fe0060
2ee677b
44068da
7a9e429
 
44068da
 
0c6b368
211cb30
 
 
 
 
 
6729627
211cb30
02e3ba7
 
 
211cb30
 
 
 
 
 
 
 
 
 
 
0c6b368
 
 
 
 
211cb30
44068da
211cb30
72b17a3
 
211cb30
 
0c6b368
211cb30
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from collections import Counter

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO

# Weight files under ./saved_model. model_names[i] is the UI label for
# model_options[i], and models[i] is the loaded model for both.
model_options = ["xViewyolov8n_v8_100e.pt", "xViewyolov8s_v8_100e.pt", "xViewyolov8m_v8_100e.pt"]
model_names = ["Nano", "Small", "Medium"]
print("before loading examples")
# Sort so the example gallery order is stable (os.listdir order is arbitrary),
# and skip hidden files such as .DS_Store, which are not example images.
example_list = [
    [os.path.join("./examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]
print("before loading models")
# Load every model once at startup so switching models in the UI does not
# reload weights on each request.
models = [YOLO(os.path.join("./saved_model", option)) for option in model_options]
print("finish preparation")

def _format_class_counts(class_names):
    """Return a per-class detection tally as display text.

    The text is a ``"Class Counts:"`` header followed by one
    ``"\\n<name>: <count>"`` entry per class. Classes appear in order of
    first occurrence in *class_names*.
    """
    counts = Counter(class_names)
    return "Class Counts:\n" + "".join(f"\n{name}: {count}" for name, count in counts.items())


def process_image(input_image, model_name, conf):
    """Run the selected YOLO model on an image and summarise its detections.

    Args:
        input_image: Image from ``gr.Image``, or None. Presumably an RGB NumPy
            array; it is converted RGB->BGR before prediction.
        model_name: One of ``model_names``. None selects the first model.
        conf: Confidence threshold passed to ``model.predict``. None means 0.6.

    Returns:
        A 2-tuple ``(annotated_image, text)`` holding the annotated RGB image
        and the per-class counts. Returns ``(None, "No objects detected.")``
        when there is no input or nothing is detected.

    Raises:
        ValueError: if *model_name* is not in ``model_names``.
    """
    no_detections = (None, "No objects detected.")
    print('start processing: ')
    if input_image is None:
        return no_detections

    # Ultralytics treats NumPy input as BGR; the UI hands us RGB.
    input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    print(f"model_name : {model_name}")
    print(f"conf : {conf}")

    if model_name is None:
        model_name = model_names[0]
    if conf is None:
        conf = 0.6

    model_index = model_names.index(model_name)
    print('model_index: ')
    print(model_index)
    model = models[model_index]

    results = model.predict(input_image, conf=conf)
    # One input image yields one Results object. Still guard the empty case so
    # that Gradio's two outputs always receive a 2-tuple.
    if not results or len(results[0].boxes) == 0:
        return no_detections
    result = results[0]

    # plot() draws the boxes on a BGR image; convert back to RGB for display.
    annotated = cv2.cvtColor(result.plot().astype(np.uint8), cv2.COLOR_BGR2RGB)
    # Class ids come back as floats; int() makes the lookup valid whether
    # `names` is a dict keyed by int or a list.
    class_names = [result.names[int(cls_id)] for cls_id in result.boxes.cls.tolist()]
    return annotated, _format_class_counts(class_names)

# Build the web UI. Inputs map positionally onto process_image(input_image,
# model_name, conf). live=True re-runs the prediction whenever an input changes.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(),
        gr.Radio(model_names, label="Choose model", value=model_names[0]),
        gr.Slider(minimum=0.2, maximum=1.0, step=0.1, label="Confidence Threshold", value=0.6),
    ],
    outputs=["image", gr.Textbox(label="More info")],
    title="YOLO Object detection. Trained on xView dataset.",
    # Built by concatenation rather than an indented triple-quoted string:
    # the description is rendered as Markdown, where deeply indented lines
    # would turn into a code block.
    description=(
        "The xView dataset is composed of satellite images collected from "
        "WorldView-3 satellites at a 0.3m ground sample distance.\n\n"
        "It contains over 1 million objects across 60 classes in over "
        "1,400 km² of imagery.\n"
        "https://challenge.xviewdataset.org"
    ),
    live=True,
    examples=example_list,
)

iface.launch()