# HandwritingOld / app.py
# MinhLe999's picture
# Push Claude fix 3
# 344495c
import onnxruntime as ort
import numpy as np
from PIL import Image
import gradio as gr
from torchvision import transforms
# ---- config ----
MODEL_PATH = "mobilenetv3_binary_merged.onnx"
IMG_SIZE = 256
# ---- ONNX session ----
sess = ort.InferenceSession(
MODEL_PATH,
providers=["CPUExecutionProvider"]
)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
# ---- preprocessing ----
preprocess = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-x))
def predict(image: Image.Image):
if image is None:
return 0.0, 0.0, 0
image = image.convert("RGB")
x = preprocess(image).unsqueeze(0).numpy().astype(np.float32)
logits = sess.run(
[output_name],
{input_name: x}
)[0]
logit = float(logits.reshape(-1)[0])
prob = float(sigmoid(logit))
pred = int(prob > 0.5)
return logit, prob, pred
# ---- Gradio UI ----
with gr.Blocks() as demo:
gr.Markdown("## MobileNetV3 Handwriting Binary Classifier (ONNX)")
inp = gr.Image(type="pil", label="Input image")
btn = gr.Button("Run inference")
out_logit = gr.Number(label="Logit")
out_prob = gr.Number(label="Probability")
out_pred = gr.Number(label="Prediction (0/1)")
btn.click(
fn=predict,
inputs=inp,
outputs=[out_logit, out_prob, out_pred],
api_name=False, # ADD THIS LINE
)
if __name__ == "__main__":
demo.launch() # Remove share=True and show_api parameters