gui-element-classifier / inference_example.py
diogoneno's picture
initial release: 15-class MobileNetV3-small GUI element classifier
05f3c62 verified
Raw
History Blame Contribute Delete
3.82 kB
"""Self-contained inference example for the 15-class UI-element classifier.
Run:
pip install onnxruntime numpy pillow
python inference_example.py path/to/element_crop.png
Designed to drop into an LLM-orchestration loop where you have a screenshot,
a list of detected element bounding boxes (from any detector — YOLOv8, OWL-ViT,
SAM-then-filter, accessibility tree, etc.), and you need cheap, deterministic
per-element type labels before passing them to a reasoning LLM.
Inference is CPU-friendly (~5 ms per crop on a modern x86 laptop). Use it as a
'helper' that adds structure to the orchestrator's prompt — e.g., 'click the
text_input near label "Username"' — instead of paying VLM tokens to look at
every crop.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
import numpy as np
import onnxruntime as ort
from PIL import Image
# Directory holding this script; the model file and class list are looked up
# next to it rather than relative to the current working directory.
HERE = Path(__file__).parent
ONNX_PATH = HERE / "mobilenetv3_small.onnx"

# Class names, read once at import time from the "classes" key of classes.json.
# The list is indexed by the model's arg-max output index, so its order matters.
# JSON is UTF-8 by specification; decode explicitly instead of relying on the
# platform's locale default (which is not UTF-8 on every system).
CLASSES = json.loads((HERE / "classes.json").read_text(encoding="utf-8"))["classes"]

# Per-channel (R, G, B) mean / std applied after scaling pixels to [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def pad_to_square(img: Image.Image) -> Image.Image:
    """Centre *img* on a square gray canvas sized to its longer edge.

    The canvas is a new RGB image filled with (128, 128, 128); the input is
    pasted so that any padding is split evenly between the two sides of the
    shorter dimension. The gray fill has to match the training transform.
    """
    width, height = img.size
    side = max(width, height)
    # Integer division: with an odd amount of padding the extra pixel ends up
    # on the right / bottom.
    top_left = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new("RGB", (side, side), (128, 128, 128))
    canvas.paste(img, top_left)
    return canvas
def preprocess(img: Image.Image) -> np.ndarray:
    """Turn an image into the float32 tensor of shape (1, 3, 224, 224) fed to the model.

    Pipeline, in order: convert to RGB, pad to square, bilinear resize to
    224x224, scale to [0, 1], normalise with the ImageNet mean/std, reorder
    HWC -> CHW, and prepend a batch axis of length 1.
    """
    squared = pad_to_square(img.convert("RGB"))
    resized = squared.resize((224, 224), Image.BILINEAR)

    pixels = np.array(resized, dtype=np.float32) / 255.0
    # Mean/std have shape (3,) and broadcast over the trailing channel axis.
    normalized = (pixels - IMAGENET_MEAN) / IMAGENET_STD

    channels_first = normalized.transpose(2, 0, 1)
    return channels_first[None, :, :, :].astype(np.float32)  # (1, 3, 224, 224)
def softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of *x* taken along axis 1 (one distribution per row).

    Each row's maximum is subtracted before exponentiating so that large
    logits do not overflow; this leaves the result mathematically unchanged.
    """
    row_peak = np.max(x, axis=1, keepdims=True)
    weights = np.exp(x - row_peak)
    normaliser = weights.sum(axis=1, keepdims=True)
    return weights / normaliser
# Lazily-created ONNX Runtime session shared by classify() and classify_batch().
# Constructing a session loads the model file, so it is done once per process
# rather than on every call.
_SESSION = None


def _get_session() -> ort.InferenceSession:
    """Return the shared CPU inference session, creating it on first use."""
    global _SESSION
    if _SESSION is None:
        options = ort.SessionOptions()
        options.intra_op_num_threads = 4
        _SESSION = ort.InferenceSession(
            str(ONNX_PATH),
            sess_options=options,
            providers=["CPUExecutionProvider"],
        )
    return _SESSION


def _predict_probs(sess: ort.InferenceSession, input_name: str, crop_path: str | Path) -> np.ndarray:
    """Run the model on one crop file and return its 1-D class-probability vector."""
    # Image.open() is lazy and keeps the file handle open; the context manager
    # releases it once preprocess() has copied the pixels into a numpy array.
    with Image.open(crop_path) as img:
        batch = preprocess(img)
    logits = sess.run(None, {input_name: batch})[0]
    return softmax(logits)[0]


def classify(crop_path: str | Path) -> dict:
    """Classify a single element crop.

    Args:
        crop_path: Path to an image file containing one cropped UI element.

    Returns:
        A dict with three keys: ``"label"`` (the top-scoring class name),
        ``"confidence"`` (its softmax probability as a float) and ``"scores"``
        (a mapping from every class name to its probability).
    """
    sess = _get_session()
    probs = _predict_probs(sess, sess.get_inputs()[0].name, crop_path)
    idx = int(np.argmax(probs))
    return {
        "label": CLASSES[idx],
        "confidence": float(probs[idx]),
        "scores": {c: float(probs[i]) for i, c in enumerate(CLASSES)},
    }


def classify_batch(crop_paths: list[str | Path]) -> list[dict]:
    """Classify several crops, one after another, with a single session.

    The shipped ONNX model has a fixed batch size of 1, so this is a per-crop
    loop. For higher throughput on large batches, re-export the model with
    dynamic axes and run a single batched ``session.run()``; it is kept simple
    here for clarity.

    Args:
        crop_paths: Paths to the crop image files, in the order results are wanted.

    Returns:
        One dict per input path, in the same order, each with ``"label"`` and
        ``"confidence"`` keys (no per-class ``"scores"`` map).
    """
    sess = _get_session()
    # The input tensor name does not change between runs; look it up once.
    input_name = sess.get_inputs()[0].name
    results = []
    for path in crop_paths:
        probs = _predict_probs(sess, input_name, path)
        idx = int(np.argmax(probs))
        results.append({
            "label": CLASSES[idx],
            "confidence": float(probs[idx]),
        })
    return results
def main(argv: list[str]) -> int:
    """Command-line entry point: classify every crop path given in ``argv[1:]``.

    Args:
        argv: Full argument vector, program name first (as in ``sys.argv``).

    Returns:
        The process exit status: 0 after printing one result line per path,
        1 when no crop path was supplied.
    """
    paths = argv[1:]
    if not paths:
        # A usage error is a diagnostic, not a result: keep it off stdout so
        # callers piping the output never mistake it for a classification.
        print("usage: python inference_example.py <crop.png> [<crop2.png> ...]", file=sys.stderr)
        return 1
    # classify_batch() builds the inference session once for all paths instead
    # of once per file.
    for path, result in zip(paths, classify_batch(paths)):
        print(f"{path}: {result['label']} (confidence={result['confidence']:.3f})")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))