| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser |
| from pathlib import Path |
|
|
| import gradio as gr |
| import numpy as np |
| import onnxruntime |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
|
|
|
|
def main(args):
    """Download a Holocron model from the Hugging Face Hub and serve it in a Gradio demo.

    Args:
        args: parsed CLI arguments. ``args.repo`` is the HF Hub repo holding
            ``config.json`` and ``model.onnx``. ``args.port`` is forwarded to
            ``interface.launch(server_port=...)`` and may be ``None``.
    """
    # Download the model config & ONNX checkpoint
    with Path(hf_hub_download(args.repo, filename="config.json")).open("rb") as f:
        cfg = json.load(f)

    ort_session = onnxruntime.InferenceSession(hf_hub_download(args.repo, filename="model.onnx"))
    # The model input name is fixed for the session, so look it up once instead of per request
    input_name = ort_session.get_inputs()[0].name

    # PIL's resize expects (W, H) while the config's input_shape ends with (H, W)
    target_size = tuple(cfg["input_shape"][-2:][::-1])
    # Per-channel normalization stats, shaped (C, 1, 1) to broadcast over a (C, H, W) image
    mean = np.asarray(cfg["mean"], dtype=np.float32)[:, None, None]
    std = np.asarray(cfg["std"], dtype=np.float32)[:, None, None]

    def preprocess_image(pil_img: Image.Image) -> np.ndarray:
        """Preprocess an image for inference

        Args:
            pil_img: a valid pillow image

        Returns:
            the resized and normalized float32 image of shape (1, C, H, W)
        """
        # Force 3 channels: a grayscale image would give a 2D array that cannot be transposed
        # to (C, H, W), and RGBA would add a 4th channel that mismatches mean/std.
        # NOTE(review): assumes the checkpoint expects RGB input (3-entry mean/std).
        img = pil_img.convert("RGB").resize(target_size, Image.BILINEAR)
        # (H, W, C) --> (C, H, W), scaled to [0, 1]
        img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
        # Normalization (stays float32, as required by the ONNX session input)
        img = (img - mean) / std

        # Add the batch dimension
        return img[None, ...]

    def predict(image):
        """Classify an image and return per-class probabilities.

        Args:
            image: the pillow image coming from the Gradio input, or ``None`` if there is none

        Returns:
            a mapping from class name to probability (empty when no image is provided)
        """
        # With live=True the callback may fire without an image (e.g. once the input is cleared)
        if image is None:
            return {}

        # Inference
        ort_out = ort_session.run(None, {input_name: preprocess_image(image)})
        # Softmax over the single batch element; subtracting the max keeps exp() from
        # overflowing to inf (which would turn every probability into NaN)
        logits = ort_out[0][0]
        out_exp = np.exp(logits - logits.max())
        probs = out_exp / out_exp.sum()

        # strict=True surfaces a mismatch between the config's classes and the model output size
        return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs, strict=True)}

    interface = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=3),
        title="Holocron: image classification demo",
        article=(
            "<p style='text-align: center'><a href='https://github.com/frgfm/holocron'>"
            "Github Repo</a> | "
            "<a href='https://frgfm.github.io/holocron/'>Documentation</a></p>"
        ),
        # Re-run the prediction whenever the input changes
        live=True,
    )

    interface.launch(server_port=args.port, show_error=True)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect the demo options, then hand them to main()
    cli = ArgumentParser(
        description="Holocron image classification demo",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        "--repo",
        type=str,
        default="frgfm/rexnet1_0x",
        help="HF Hub repo to use",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=None,
        help="Port on which the webserver will be run",
    )

    main(cli.parse_args())
|
|