scoutbot / app.py
bluemellophone's picture
Add MVP localizer, quiet logging to stdout, config aliasing, and updated documentation
79b792e unverified
# -*- coding: utf-8 -*-
import time
import cv2
import gradio as gr
import numpy as np
from scoutbot import loc, wic
def predict(filepath, config, wic_thresh, loc_thresh, nms_thresh):
start = time.time()
if config == 'MVP':
config = 'mvp'
elif config == 'Phase 1':
config = 'phase1'
else:
raise ValueError()
wic_thresh /= 100.0
loc_thresh /= 100.0
nms_thresh /= 100.0
nms_thresh = 1.0 - nms_thresh
# Load data
img = cv2.imread(filepath)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Run WIC
inputs = [filepath]
outputs = wic.post(wic.predict(wic.pre(inputs, config=config)))
# Get WIC confidence
output = outputs[0]
wic_confidence = output.get('positive')
loc_detections = []
if wic_confidence > wic_thresh:
# Run Localizer
outputs = loc.post(
loc.predict(loc.pre(inputs, config=config)),
loc_thresh=loc_thresh,
nms_thresh=nms_thresh,
)
# Format and render results
detects = outputs[0]
for detect in detects:
label = detect['l']
conf = detect['c']
if conf >= loc_thresh:
point1 = (
int(np.around(detect['x'])),
int(np.around(detect['y'])),
)
point2 = (
int(np.around(detect['x'] + detect['w'])),
int(np.around(detect['y'] + detect['h'])),
)
color = (255, 0, 0)
img = cv2.rectangle(img, point1, point2, color, 2)
loc_detections.append(f'{label}: {conf:0.04f}')
loc_detections = '\n'.join(loc_detections)
end = time.time()
duration = end - start
speed = f'{duration:0.02f} seconds)'
return img, speed, wic_confidence, loc_detections
interface = gr.Interface(
fn=predict,
title='Wild Me Scout - Tile ML Demo',
inputs=[
gr.Image(type='filepath'),
gr.Radio(
label='Model Configuration',
type='value',
choices=['Phase 1', 'MVP'],
value='MVP',
),
gr.Slider(label='WIC Confidence Threshold', value=7),
gr.Slider(label='Localizer Confidence Threshold', value=14),
gr.Slider(label='Localizer NMS Threshold', value=80),
],
outputs=[
gr.Image(type='numpy'),
gr.Textbox(label='Prediction Speed', interactive=False),
gr.Number(label='Predicted WIC Confidence', precision=5, interactive=False),
gr.Textbox(label='Predicted Localizer Detections', interactive=False),
],
examples=[
['examples/07a4b8db-f31c-261d-4580-e9402768fd45.true.jpg', 'MVP', 7, 14, 80],
['examples/15e815d9-5aad-fa53-d1ed-33429020e15e.true.jpg', 'MVP', 7, 14, 80],
['examples/1bb79811-3149-7a60-2d88-613dc3eeb261.true.jpg', 'MVP', 7, 14, 80],
['examples/1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg', 'MVP', 7, 14, 80],
['examples/201bc65e-d64e-80d3-2610-5865a22d04b4.false.jpg', 'MVP', 7, 14, 80],
['examples/3affd8b6-9722-f2d5-9171-639615b4c38f.true.jpg', 'MVP', 7, 14, 80],
['examples/4aedb818-f2f4-e462-8b75-5c8e34a01a59.false.jpg', 'MVP', 7, 14, 80],
['examples/474bc2b6-dc51-c1b5-4612-efe810bbe091.true.jpg', 'MVP', 7, 14, 80],
['examples/c3014107-3464-60b5-e04a-e4bfafdf8809.false.jpg', 'MVP', 7, 14, 80],
['examples/f835ce33-292a-9116-794e-f8859b5956ec.true.jpg', 'MVP', 7, 14, 80],
],
cache_examples=True,
allow_flagging='never',
)
interface.launch(server_name='0.0.0.0')