import numpy as np
import gradio as gr
from PIL import Image
import onnxruntime as ort
# Create the ONNX Runtime session once at import time; predict() reuses it
# for every call instead of reloading the model file.
sess = ort.InferenceSession("mobilevit_qat.onnx")

# Name of the model's first declared input; used as the feed-dict key.
inp_name = sess.get_inputs()[0].name

# Per-channel normalisation constants (three float32 values each). They are
# broadcast over the last axis of the H x W x 3 image array in predict().
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def predict(image):
    """Classify one frame and return the rendered confidence bar.

    Args:
        image: A NumPy array, an image object exposing ``convert``/``resize``
            (e.g. a PIL image), or ``None`` when no frame is available.

    Returns:
        The string produced by ``bar_html``. For ``None`` input this is the
        neutral ``bar_html(50, 50)``; otherwise the two arguments are the
        softmax scores of output index 1 and index 0, scaled to percent.
    """
    # Nothing to classify yet: show the neutral 50/50 bar.
    if image is None:
        return bar_html(50, 50)

    # Arrays are wrapped into a PIL image; anything else is used as-is.
    if isinstance(image, np.ndarray):
        pil_img = Image.fromarray(image)
    else:
        pil_img = image
    pil_img = pil_img.convert("RGB")
    pil_img = pil_img.resize((224, 224))

    # Scale pixel values to [0, 1], then normalise each channel.
    pixels = np.array(pil_img).astype(np.float32) / 255
    normed = (pixels - MEAN) / STD

    # Height-width-channel -> channel-height-width, plus a leading batch axis.
    chw = np.transpose(normed, (2, 0, 1))
    batch = np.expand_dims(chw, axis=0)

    outputs = sess.run(None, {inp_name: batch})
    logits = outputs[0].ravel()

    # Softmax; subtracting the max first keeps exp() from overflowing.
    exps = np.exp(logits - logits.max())
    probs = exps / exps.sum()

    person_pct = float(probs[1]) * 100
    bg_pct = float(probs[0]) * 100
    return bar_html(person_pct, bg_pct)
def bar_html(person, bg):
    """Render the prediction as a small inline-styled HTML snippet.

    The original template was empty, so the computed label and colour were
    never shown; this version actually emits them together with both scores.

    Args:
        person: Score for the "person" class, in percent.
        bg: Score for the "no person" (background) class, in percent.

    Returns:
        An HTML string with a headline ("PERSON" when ``person > bg``,
        otherwise "NO PERSON"), a horizontal bar whose filled width is the
        person score, and a caption with both percentages. A tie counts as
        "NO PERSON".
    """
    is_person = person > bg
    label = "PERSON" if is_person else "NO PERSON"
    clr = "#4ade80" if is_person else "#f87171"
    # Keep the CSS width valid even if a caller passes an out-of-range score.
    width = min(max(float(person), 0.0), 100.0)
    return f"""
<div style="font-family:sans-serif;padding:8px">
  <div style="font-size:1.5em;font-weight:700;color:{clr}">{label}</div>
  <div style="background:#e5e7eb;border-radius:6px;height:14px;overflow:hidden;margin:6px 0">
    <div style="width:{width:.1f}%;height:100%;background:{clr}"></div>
  </div>
  <div style="font-size:0.85em">person {float(person):.1f}% · background {float(bg):.1f}%</div>
</div>
"""
# Two-tab UI: a live webcam feed and a single-image upload. Both tabs send
# their image to predict() and show the returned HTML next to the input.
with gr.Blocks(theme=gr.themes.Soft(), title="VWW MobileViT") as demo:
    gr.Markdown("# Visual Wake Words — MobileViT-XXS\nPerson detection trained on COCO VWW · QAT pipeline · ONNX Runtime")

    with gr.Tabs():
        with gr.TabItem("Webcam"):
            with gr.Row():
                cam = gr.Image(sources=["webcam"], streaming=True, label="Feed")
                # Start with the neutral bar until the first frame arrives.
                cam_out = gr.HTML(bar_html(50, 50))
            # Re-run the classifier on streamed frames.
            cam.stream(
                fn=predict,
                inputs=cam,
                outputs=cam_out,
                time_limit=120,
            )

        with gr.TabItem("Upload"):
            with gr.Row():
                img_in = gr.Image(type="pil", label="Image")
                img_out = gr.HTML(bar_html(50, 50))
            # Fires whenever the image value changes; predict() handles the
            # None it receives when the image is cleared.
            img_in.change(fn=predict, inputs=img_in, outputs=img_out)

    with gr.Accordion("About", open=False):
        gr.Markdown(
            "MobileViT-XXS with quantization-aware training on Visual Wake Words (COCO 2014). "
            "Conv2d layers trained with per-channel INT8 fake quantisation; transformer blocks kept at FP32. "
            "Exported as ONNX. [Paper](https://arxiv.org/abs/2110.02178)"
        )

# Start the web server (runs at import time, exactly like the original).
demo.launch()