MinhLe999 commited on
Commit
4d1fd74
·
0 Parent(s):

Gradio ONNX app

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +78 -0
  3. mobilenetv3_binary_merged.onnx +3 -0
  4. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.onnx filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from torchvision import transforms
6
+
7
+
8
+ # ---- config ----
9
+ MODEL_PATH = "mobilenetv3_binary_merged.onnx"
10
+ IMG_SIZE = 256
11
+ to_rgb = transforms.Lambda(lambda img: img.convert("RGB"))
12
+
13
+ # ---- ONNX session ----
14
+ sess = ort.InferenceSession(
15
+ MODEL_PATH,
16
+ providers=["CPUExecutionProvider"]
17
+ )
18
+
19
+ input_name = sess.get_inputs()[0].name
20
+ output_name = sess.get_outputs()[0].name
21
+
22
+
23
+ # ---- preprocessing ----
24
+ preprocess = transforms.Compose([
25
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
26
+ to_rgb,
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(
29
+ mean=[0.485, 0.456, 0.406],
30
+ std=[0.229, 0.224, 0.225],
31
+ ),
32
+ ])
33
+
34
+
35
+ def sigmoid(x):
36
+ return 1.0 / (1.0 + np.exp(-x))
37
+
38
+
39
+ def predict(image: Image.Image):
40
+ if image is None:
41
+ return 0.0, 0.0, 0
42
+
43
+ image = image.convert("RGB")
44
+ x = preprocess(image).unsqueeze(0).numpy().astype(np.float32)
45
+
46
+ logits = sess.run(
47
+ [output_name],
48
+ {input_name: x}
49
+ )[0]
50
+
51
+ logit = float(logits.reshape(-1)[0])
52
+ prob = float(sigmoid(logit))
53
+ pred = int(prob > 0.5)
54
+
55
+ return logit, prob, pred
56
+
57
+
58
+ # ---- Gradio UI ----
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("## MobileNetV3 Handwriting Binary Classifier (ONNX)")
61
+
62
+ inp = gr.Image(type="pil", label="Input image")
63
+ btn = gr.Button("Run inference")
64
+
65
+ out_logit = gr.Number(label="Logit")
66
+ out_prob = gr.Number(label="Probability")
67
+ out_pred = gr.Number(label="Prediction (0/1)")
68
+
69
+ btn.click(
70
+ fn=predict,
71
+ inputs=inp,
72
+ outputs=[out_logit, out_prob, out_pred],
73
+ api_name=False, # ADD THIS LINE
74
+ )
75
+
76
+
77
+ if __name__ == "__main__":
78
+ demo.launch() # Remove share=True and show_api parameters
mobilenetv3_binary_merged.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1597ee2fe5bd5565ba5914abfd0de8e2fa4828fb5c4001311a16b26942b58206
3
+ size 17146607
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.0.0
2
+ huggingface_hub<1.0
3
+ onnxruntime
4
+ torch
5
+ torchvision
6
+ numpy
7
+ pillow
8
+ pydantic<2.11