# cifar10-api / app.py
# Last commit: d41f497 "Update app.py" by mramjad (verified).
import gradio as gr
import numpy as np
from PIL import Image
import onnxruntime as ort # For ONNX inference
# Open the ONNX model once, when this module is imported, so every prediction
# reuses the same inference session instead of re-reading the file per call.
# The path is relative, so it resolves against the current working directory.
model_path = "cifar10_model.onnx"
ort_session = ort.InferenceSession(model_path)

# CIFAR-10 class names. Their order must match the model's output units
# (presumably the standard CIFAR-10 index order — confirm against training).
labels = "airplane automobile bird cat deer dog frog horse ship truck".split()
def preprocess_image(image, size=(32, 32)):
    """Convert a PIL image into the classifier's input tensor.

    Args:
        image: A PIL image of any mode (RGB, RGBA, grayscale, palette, ...).
        size: Target ``(width, height)`` passed to ``Image.resize``; defaults
            to the 32x32 CIFAR-10 resolution.

    Returns:
        A C-contiguous ``float32`` array of shape ``(1, 3, height, width)``
        (batch, channels, height, width) with pixel values scaled to [0, 1].
    """
    # Force 3-channel RGB first. Without this, grayscale images give a 2-D
    # array (so the transpose below fails) and RGBA/CMYK images give a
    # 4-channel array the model cannot accept. Converting before resizing also
    # avoids Pillow's nearest-neighbour-only resampling of palette ("P") images.
    image = image.convert("RGB").resize(size)
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    # HWC -> CHW, then prepend the batch axis.
    # NOTE(review): this assumes the ONNX model expects NCHW input. The input
    # name used by the app suggests a Keras/TF export, and those are NHWC
    # unless converted with an NCHW option — confirm the layout in Netron.
    chw = pixels.transpose(2, 0, 1)
    # The transpose yields a strided view; hand the runtime a contiguous buffer.
    return np.ascontiguousarray(chw[np.newaxis])
def predict(image):
    """Classify *image* and return per-class confidences for ``gr.Label``.

    Args:
        image: A PIL image from ``gr.Image(type="pil")``. It is ``None`` when
            the user submits without uploading anything.

    Returns:
        A dict mapping each class label to a confidence in [0, 1]. ``gr.Label``
        renders the highest-scoring entries, so its ``num_top_classes`` setting
        actually takes effect; a bare label string only ever shows one class.
        The top entry is the same class the argmax of the model output selects.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    input_data = preprocess_image(image)
    # Read the input name from the model rather than hard-coding the value
    # found in Netron, so a re-exported model with a renamed input still works.
    input_name = ort_session.get_inputs()[0].name
    raw = ort_session.run(None, {input_name: input_data})[0]
    # Batch of one: flatten (1, num_classes) down to (num_classes,).
    scores = np.asarray(raw, dtype=np.float64).reshape(-1)

    # The model may end in a softmax or emit raw logits. Only normalise when
    # the scores are not already a probability distribution. Softmax is
    # monotonic, so the ranking of classes is unchanged either way.
    if scores.min() < 0 or not np.isclose(scores.sum(), 1.0, atol=1e-3):
        exp = np.exp(scores - scores.max())  # shift by max for numerical stability
        scores = exp / exp.sum()

    # Plain Python floats keep the payload JSON-serialisable for Gradio.
    return {label: float(score) for label, score in zip(labels, scores)}
# Build the Gradio UI: one image upload feeding `predict`, with the result
# shown in a Label component that displays up to three classes. Binding it to
# `demo` also lets Gradio's reload mode (`gradio app.py`) find the app.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="CIFAR-10 Classifier",
    description="Upload an image to classify it into one of the CIFAR-10 classes.",
)

# Start the server only when this file runs as a script (`python app.py`), so
# importing the module no longer blocks inside launch().
# NOTE(review): Hugging Face Spaces runs app.py as a script, so this guard
# should still fire there — confirm after deploying.
if __name__ == "__main__":
    demo.launch()  # Add share=True for a public link