File size: 3,634 Bytes
3aad127
79b792e
 
3072e1b
3aad127
799c02f
3aad127
3072e1b
3aad127
 
79b792e
 
 
 
 
 
 
 
 
 
73f6108
 
 
 
79b792e
 
799c02f
 
 
3aad127
799c02f
79b792e
 
3aad127
799c02f
73f6108
799c02f
3aad127
799c02f
 
73f6108
 
d00fd73
79b792e
 
 
d00fd73
3aad127
73f6108
 
799c02f
73f6108
 
 
799c02f
73f6108
 
799c02f
 
73f6108
 
799c02f
 
 
73f6108
799c02f
3aad127
79b792e
 
 
 
 
3aad127
 
 
 
73f6108
799c02f
 
79b792e
 
 
 
 
 
 
 
 
799c02f
 
 
79b792e
799c02f
 
 
 
79b792e
 
 
 
 
 
 
 
 
 
799c02f
 
 
3aad127
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
import time

import cv2
import gradio as gr
import numpy as np

from scoutbot import loc, wic


def predict(filepath, config, wic_thresh, loc_thresh, nms_thresh):
    """Run the WIC and, if it fires, the localizer on one image.

    Args:
        filepath: Path of the image to analyze. It is read with OpenCV for
            rendering and is also handed to the ``wic``/``loc`` pipelines.
        config: Model configuration label, either ``'MVP'`` or ``'Phase 1'``.
        wic_thresh: WIC confidence threshold expressed as a percentage.
        loc_thresh: Localizer confidence threshold expressed as a percentage.
        nms_thresh: Localizer NMS threshold expressed as a percentage. It is
            inverted (``1 - value``) before being passed to ``loc.post``.

    Returns:
        A tuple ``(img, speed, wic_confidence, loc_detections)`` where
        ``img`` is the RGB image with the kept detections drawn on it,
        ``speed`` is the elapsed wall-clock time as text, ``wic_confidence``
        is the ``'positive'`` value of the WIC output, and ``loc_detections``
        is a newline-joined ``'label: confidence'`` listing (empty when the
        localizer did not run or nothing passed the threshold).

    Raises:
        ValueError: If ``config`` is not a known label, if no image path was
            given, or if the image cannot be read.
    """
    start = time.time()

    if config == 'MVP':
        config = 'mvp'
    elif config == 'Phase 1':
        config = 'phase1'
    else:
        raise ValueError(
            f"Unknown model configuration {config!r}; expected 'MVP' or 'Phase 1'"
        )

    if not filepath:
        raise ValueError('No input image was provided')

    # The thresholds arrive as percentages; rescale them to fractions.
    wic_thresh /= 100.0
    loc_thresh /= 100.0
    nms_thresh /= 100.0

    # NOTE(review): the NMS value is inverted before use; presumably the UI
    # exposes the complement of what loc.post expects -- confirm.
    nms_thresh = 1.0 - nms_thresh

    # Load data
    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread signals an unreadable file by returning None, not raising.
        raise ValueError(f'Unable to read image file: {filepath!r}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Run WIC
    inputs = [filepath]
    outputs = wic.post(wic.predict(wic.pre(inputs, config=config)))

    # Get WIC confidence (single input, so only the first output is used)
    output = outputs[0]
    wic_confidence = output.get('positive')

    loc_detections = []
    if wic_confidence > wic_thresh:

        # Run Localizer
        outputs = loc.post(
            loc.predict(loc.pre(inputs, config=config)),
            loc_thresh=loc_thresh,
            nms_thresh=nms_thresh,
        )

        # Format and render results
        color = (255, 0, 0)  # red, since the image is now in RGB order
        for detect in outputs[0]:
            label = detect['l']
            conf = detect['c']
            # Only detections at or above the confidence threshold are shown.
            if conf < loc_thresh:
                continue
            point1 = (
                int(np.around(detect['x'])),
                int(np.around(detect['y'])),
            )
            point2 = (
                int(np.around(detect['x'] + detect['w'])),
                int(np.around(detect['y'] + detect['h'])),
            )
            img = cv2.rectangle(img, point1, point2, color, 2)
            loc_detections.append(f'{label}: {conf:0.04f}')
    loc_detections = '\n'.join(loc_detections)

    duration = time.time() - start
    # The original text ended with an unmatched ')' -- removed.
    speed = f'{duration:0.02f} seconds'

    return img, speed, wic_confidence, loc_detections


# File names of the example images, all located under ``examples/``.
_EXAMPLE_IMAGES = [
    '07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg',
    '15e815d9-5aad-fa53-d1ed-33429020e15e.true.jpg',
    '1bb79811-3149-7a60-2d88-613dc3eeb261.true.jpg',
    '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg',
    '201bc65e-d64e-80d3-2610-5865a22d04b4.false.jpg',
    '3affd8b6-9722-f2d5-9171-639615b4c38f.true.jpg',
    '4aedb818-f2f4-e462-8b75-5c8e34a01a59.false.jpg',
    '474bc2b6-dc51-c1b5-4612-efe810bbe091.true.jpg',
    'c3014107-3464-60b5-e04a-e4bfafdf8809.false.jpg',
    'f835ce33-292a-9116-794e-f8859b5956ec.true.jpg',
]

# Each example row follows the order of the input components below: image
# path, model configuration, then the WIC, localizer and NMS slider values
# (the same values the sliders start with).
_example_rows = [
    [f'examples/{image_name}', 'MVP', 7, 14, 80]
    for image_name in _EXAMPLE_IMAGES
]

# Input components, in the positional order of ``predict``'s parameters.
_input_components = [
    gr.Image(type='filepath'),
    gr.Radio(
        label='Model Configuration',
        type='value',
        choices=['Phase 1', 'MVP'],
        value='MVP',
    ),
    gr.Slider(label='WIC Confidence Threshold', value=7),
    gr.Slider(label='Localizer Confidence Threshold', value=14),
    gr.Slider(label='Localizer NMS Threshold', value=80),
]

# Output components, in the order of ``predict``'s returned tuple.
_output_components = [
    gr.Image(type='numpy'),
    gr.Textbox(label='Prediction Speed', interactive=False),
    gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
    gr.Textbox(label='Predicted Localizer Detections', interactive=False),
]

interface = gr.Interface(
    fn=predict,
    title='Wild Me Scout - Tile ML Demo',
    inputs=_input_components,
    outputs=_output_components,
    examples=_example_rows,
    cache_examples=True,
    allow_flagging='never',
)

# Start the web app, listening on all network interfaces.
interface.launch(server_name='0.0.0.0')