# File size: 3,043 Bytes
import numpy as np
import gradio as gr
from PIL import Image
import onnxruntime as ort
# Create the ONNX Runtime session once at import time so every predict() call
# reuses it instead of reloading "mobilevit_qat.onnx".
sess = ort.InferenceSession("mobilevit_qat.onnx")
# Name of the model's first input; predict() uses it as the feed key for sess.run.
inp_name = sess.get_inputs()[0].name
# Per-channel (R, G, B) mean and standard deviation that predict() applies to
# the [0, 1]-scaled image before inference. Stored as float32 so the
# normalised input stays float32.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def predict(image):
    """Run the classifier on one image and return the result card as HTML.

    Args:
        image: ``None``, a NumPy array (wrapped with ``Image.fromarray``), or
            an object that already offers PIL's ``convert``/``resize`` API.

    Returns:
        The HTML string built by ``bar_html``. For ``None`` input this is the
        neutral 50/50 card; otherwise the softmax of the model output is
        shown as percentages, with index 1 reported as "person" and index 0
        as "background".
    """
    if image is None:
        return bar_html(50, 50)

    if isinstance(image, np.ndarray):
        pil_img = Image.fromarray(image)
    else:
        pil_img = image
    pil_img = pil_img.convert("RGB")
    pil_img = pil_img.resize((224, 224))

    # Scale pixel values to [0, 1], then normalise each channel with MEAN / STD.
    pixels = np.array(pil_img).astype(np.float32) / 255
    pixels = (pixels - MEAN) / STD

    # Move the channel axis first and add a leading batch axis of size 1.
    batch = np.expand_dims(np.transpose(pixels, (2, 0, 1)), axis=0)

    logits = sess.run(None, {inp_name: batch})[0].ravel()

    # Softmax; subtracting the max first keeps np.exp from overflowing.
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()

    person_pct = float(probs[1]) * 100
    bg_pct = float(probs[0]) * 100
    return bar_html(person_pct, bg_pct)
def bar_html(person, bg):
    """Render the person / background percentages as an HTML result card.

    Args:
        person: Percentage shown on the "Person" bar.
        bg: Percentage shown on the "Background" bar.

    Returns:
        An HTML string with a headline and two labelled bars whose widths are
        the given percentages, formatted to one decimal place. The headline
        reads "PERSON" (green) only when ``person`` is strictly greater than
        ``bg``; otherwise, including ties, it reads "NO PERSON" (red).
    """
    if person > bg:
        label, clr = "PERSON", "#4ade80"
    else:
        label, clr = "NO PERSON", "#f87171"
    return f"""
<div style="font-family:sans-serif;padding:18px">
<p style="text-align:center;font-size:1.4em;font-weight:700;color:{clr};margin:0 0 18px">{label}</p>
<div style="margin-bottom:12px">
<div style="display:flex;justify-content:space-between;margin-bottom:4px">
<span style="color:#ccc">Person</span>
<span style="color:#4ade80;font-weight:700;font-variant-numeric:tabular-nums">{person:.1f}%</span>
</div>
<div style="background:#1e293b;border-radius:8px;height:16px;overflow:hidden">
<div style="width:{person:.1f}%;height:100%;background:#4ade80;border-radius:8px;transition:width .3s"></div>
</div>
</div>
<div>
<div style="display:flex;justify-content:space-between;margin-bottom:4px">
<span style="color:#ccc">Background</span>
<span style="color:#94a3b8;font-weight:700;font-variant-numeric:tabular-nums">{bg:.1f}%</span>
</div>
<div style="background:#1e293b;border-radius:8px;height:16px;overflow:hidden">
<div style="width:{bg:.1f}%;height:100%;background:#64748b;border-radius:8px;transition:width .3s"></div>
</div>
</div>
</div>"""
# Gradio UI: a live-webcam tab and an image-upload tab, both routed through
# predict(), plus a collapsed "About" panel.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the file — confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="VWW MobileViT") as demo:
    gr.Markdown("# Visual Wake Words — MobileViT-XXS\nPerson detection trained on COCO VWW · QAT pipeline · ONNX Runtime")

    with gr.Tabs():
        with gr.TabItem("Webcam"):
            with gr.Row():
                cam = gr.Image(sources=["webcam"], streaming=True, label="Feed")
                # Start with the neutral 50/50 card until the first frame arrives.
                cam_out = gr.HTML(bar_html(50, 50))
            # Each streamed frame is passed to predict().
            cam.stream(fn=predict, inputs=cam, outputs=cam_out, time_limit=120)

        with gr.TabItem("Upload"):
            with gr.Row():
                img_in = gr.Image(type="pil", label="Image")
                img_out = gr.HTML(bar_html(50, 50))
            # Re-run predict() whenever the uploaded image changes (or is cleared).
            img_in.change(fn=predict, inputs=img_in, outputs=img_out)

    with gr.Accordion("About", open=False):
        gr.Markdown(
            "MobileViT-XXS with quantization-aware training on Visual Wake Words (COCO 2014). "
            "Conv2d layers trained with per-channel INT8 fake quantisation; transformer blocks kept at FP32. "
            "Exported as ONNX. [Paper](https://arxiv.org/abs/2110.02178)"
        )

demo.launch()
|