File size: 3,043 Bytes
57faaa3
 
09e9254
ac43253
92143f3
614f379
 
57faaa3
614f379
 
57faaa3
ab9f499
18de1a7
 
614f379
09e9254
614f379
2c69990
3a2a8a3
614f379
 
 
ac43253
614f379
 
 
ac43253
614f379
2c69990
69afc3c
614f379
 
 
69afc3c
614f379
 
 
 
 
 
ac43253
614f379
 
ac43253
614f379
 
 
 
 
ac43253
614f379
 
 
 
ac43253
 
 
614f379
 
2c69990
 
614f379
2c69990
614f379
 
 
2c69990
614f379
2c69990
614f379
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import gradio as gr
from PIL import Image
import onnxruntime as ort

# Load the exported model once at import time so every request reuses the same
# session. The provider list is given explicitly: since onnxruntime 1.9, builds
# that ship more than one execution provider (e.g. the onnxruntime-gpu wheel)
# raise ValueError when `providers` is omitted. The CPU provider is always
# available and is what a CPU-only build would use anyway.
sess = ort.InferenceSession("mobilevit_qat.onnx", providers=["CPUExecutionProvider"])
# Name of the model's first input; the preprocessed image tensor is fed under it.
inp_name = sess.get_inputs()[0].name

# Per-channel normalisation constants, broadcast over the last (channel) axis of
# an RGB image that has already been scaled to [0, 1]. These are the widely used
# ImageNet statistics — presumably the same values used at training time;
# confirm against the training pipeline.
MEAN = np.float32([0.485, 0.456, 0.406])
STD = np.float32([0.229, 0.224, 0.225])


def predict(image):
    """Classify one image as person vs. background and render the result.

    Args:
        image: A NumPy array or a PIL image to classify, or None when the
            input component is empty.

    Returns:
        The HTML fragment built by ``bar_html`` from the person and background
        probabilities expressed as percentages. When ``image`` is None, a
        neutral 50/50 display is returned without running the model.
    """
    # Nothing to classify yet: show the neutral placeholder.
    if image is None:
        return bar_html(50, 50)

    # Bring the input to a 224x224, 8-bit RGB PIL image.
    pil_img = image
    if isinstance(pil_img, np.ndarray):
        pil_img = Image.fromarray(pil_img)
    rgb = pil_img.convert("RGB").resize((224, 224))

    # Scale pixels to [0, 1], standardise each channel, then reorder
    # HWC -> CHW and prepend a batch axis of size 1.
    pixels = np.asarray(rgb).astype(np.float32) / 255
    pixels = (pixels - MEAN) / STD
    batch = np.expand_dims(pixels.transpose(2, 0, 1), axis=0)

    # Softmax over the flattened first model output; subtracting the max keeps
    # exp() from overflowing.
    logits = sess.run(None, {inp_name: batch})[0].ravel()
    shifted = np.exp(logits - np.max(logits))
    probs = shifted / np.sum(shifted)

    # Index 1 is reported as "person", index 0 as "background".
    person_pct = float(probs[1]) * 100
    background_pct = float(probs[0]) * 100
    return bar_html(person_pct, background_pct)


def bar_html(person, bg):
    """Render person/background percentages as two labelled HTML bars.

    Args:
        person: Person score in percent; sets the top bar's width and text.
        bg: Background score in percent; sets the bottom bar's width and text.

    Returns:
        An HTML fragment with a headline verdict followed by the two bars,
        each value formatted to one decimal place. The verdict is "PERSON"
        (green) when ``person`` strictly exceeds ``bg``, otherwise
        "NO PERSON" (red).
    """
    # Ties (such as the 50/50 placeholder) fall through to the negative verdict.
    if person > bg:
        verdict, accent = "PERSON", "#4ade80"
    else:
        verdict, accent = "NO PERSON", "#f87171"
    return f"""
<div style="font-family:sans-serif;padding:18px">
 <p style="text-align:center;font-size:1.4em;font-weight:700;color:{accent};margin:0 0 18px">{verdict}</p>
 <div style="margin-bottom:12px">
  <div style="display:flex;justify-content:space-between;margin-bottom:4px">
   <span style="color:#ccc">Person</span>
   <span style="color:#4ade80;font-weight:700;font-variant-numeric:tabular-nums">{person:.1f}%</span>
  </div>
  <div style="background:#1e293b;border-radius:8px;height:16px;overflow:hidden">
   <div style="width:{person:.1f}%;height:100%;background:#4ade80;border-radius:8px;transition:width .3s"></div>
  </div>
 </div>
 <div>
  <div style="display:flex;justify-content:space-between;margin-bottom:4px">
   <span style="color:#ccc">Background</span>
   <span style="color:#94a3b8;font-weight:700;font-variant-numeric:tabular-nums">{bg:.1f}%</span>
  </div>
  <div style="background:#1e293b;border-radius:8px;height:16px;overflow:hidden">
   <div style="width:{bg:.1f}%;height:100%;background:#64748b;border-radius:8px;transition:width .3s"></div>
  </div>
 </div>
</div>"""


# UI layout: a live webcam tab and a single-image upload tab, both routing
# their input through `predict` and showing the returned HTML, plus a
# collapsed "About" section.
with gr.Blocks(theme=gr.themes.Soft(), title="VWW MobileViT") as demo:
    gr.Markdown("# Visual Wake Words — MobileViT-XXS\nPerson detection trained on COCO VWW &middot; QAT pipeline &middot; ONNX Runtime")

    with gr.Tabs():
        with gr.TabItem("Webcam"):
            with gr.Row():
                cam = gr.Image(sources=["webcam"], streaming=True, label="Feed")
                # Neutral 50/50 placeholder until the first frame is classified.
                cam_out = gr.HTML(bar_html(50, 50))
            # Re-run `predict` on streamed frames; time_limit caps how long a
            # streaming session runs (seconds, per Gradio's `stream` API docs).
            cam.stream(predict, cam, cam_out, time_limit=120)

        with gr.TabItem("Upload"):
            with gr.Row():
                img_in = gr.Image(type="pil", label="Image")
                img_out = gr.HTML(bar_html(50, 50))
            # Clearing the image passes None, which `predict` maps back to the
            # neutral placeholder.
            img_in.change(predict, img_in, img_out)

    with gr.Accordion("About", open=False):
        gr.Markdown(
            "MobileViT-XXS with quantization-aware training on Visual Wake Words (COCO 2014). "
            "Conv2d layers trained with per-channel INT8 fake quantisation; transformer blocks kept at FP32. "
            "Exported as ONNX. [Paper](https://arxiv.org/abs/2110.02178)"
        )

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module for tests or tooling does not launch a web server.
if __name__ == "__main__":
    demo.launch()