Spaces:
Runtime error
Runtime error
| # Copyright (C) 2022, Pyronear. | |
| # This program is licensed under the Apache License 2.0. | |
| # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details. | |
| import argparse | |
| import json | |
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
REPO = "pyronear/rexnet1_0x"

# Fetch the model configuration from the Hugging Face Hub and parse it
config_path = hf_hub_download(REPO, filename="config.json")
with open(config_path, "rb") as f:
    cfg = json.load(f)

# Fetch the ONNX checkpoint and open an inference session on it
model_path = hf_hub_download(REPO, filename="model.onnx")
ort_session = onnxruntime.InferenceSession(model_path)
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Preprocess an image for inference

    The image is converted to RGB, resized to the spatial size stored in
    ``cfg["input_shape"]``, rescaled by 1/255 and standardized channel-wise
    with ``cfg["mean"]`` and ``cfg["std"]``.

    Args:
        pil_img: a valid pillow image (any mode, it is converted to RGB)

    Returns:
        the resized and normalized image of shape (1, C, H, W)
    """

    # Grayscale / palette images give a 2D array and RGBA gives 4 channels, which would
    # break the channel transpose and the per-channel normalization below.
    # NOTE(review): assumes the model expects 3-channel RGB input -- confirm against config.json
    rgb_img = pil_img.convert("RGB")

    # Resizing (PIL takes (W, H) order for resizing, the config stores (..., H, W))
    target_size = tuple(cfg["input_shape"][-2:][::-1])
    img = rgb_img.resize(target_size, Image.BILINEAR)
    # (H, W, C) --> (C, H, W)
    img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
    # Normalization (statistics are broadcast over the spatial dimensions)
    img -= np.array(cfg["mean"])[:, None, None]
    img /= np.array(cfg["std"])[:, None, None]

    # Add the leading batch dimension
    return img[None, ...]
def predict(image):
    """Run the ONNX classifier on an image and return per-class confidences

    Args:
        image: the image to classify, forwarded to `preprocess_image`

    Returns:
        a dictionary mapping each entry of `cfg["classes"]` to its sigmoid score
    """
    # Preprocessing
    batch = preprocess_image(image)
    input_name = ort_session.get_inputs()[0].name

    # Inference (None requests all model outputs)
    model_outputs = ort_session.run(None, {input_name: batch})

    # Post-processing: sigmoid over the first output of the first (only) sample
    logits = model_outputs[0][0]
    probs = 1 / (1 + np.exp(-logits))

    scores = {}
    for class_name, conf in zip(cfg["classes"], probs):
        scores[class_name] = float(conf)
    return scores
# Input / output components.
# NOTE: the `gr.inputs` / `gr.outputs` namespaces were deprecated in Gradio 3 and removed
# in Gradio 4, so the top-level component classes are used instead.
img = gr.Image(type="pil")
outputs = gr.Label(num_top_classes=1)

# Build the demo around `predict` and start serving it
gr.Interface(
    fn=predict,
    inputs=[img],
    outputs=outputs,
    title="PyroVision: image classification demo",
    article=(
        "<p style='text-align: center'><a href='https://github.com/pyronear/pyro-vision'>"
        "Github Repo</a> | "
        "<a href='https://pyronear.org/pyro-vision/'>Documentation</a></p>"
    ),
    live=True,
).launch()