# mramjad's picture
# Update app.py
# 1db03ad verified
import gradio as gr
import numpy as np
from PIL import Image
import onnxruntime as ort
# Build a single ONNX Runtime session when the module loads; every prediction
# request reuses it. The relative path is resolved against the current
# working directory.
model_path = "cifar10_model.onnx"
ort_session = ort.InferenceSession(path_or_bytes=model_path)
# CIFAR-10 class names; position i names model output i (predict indexes
# this list with the argmax of the model scores).
labels = "airplane automobile bird cat deer dog frog horse ship truck".split()
def preprocess_image(image):
    """Turn a PIL image into the float32 batch tensor fed to the ONNX model.

    The image is forced to 3-channel RGB, resized to 32x32, scaled from
    [0, 255] to [0.0, 1.0] and laid out as (1, 3, 32, 32):
    batch, channels, height, width.

    Args:
        image: A ``PIL.Image.Image`` in any mode (RGB, RGBA, L, P, ...).

    Returns:
        A C-contiguous ``numpy.ndarray`` of dtype float32, shape (1, 3, 32, 32).
    """
    # Uploads may be RGBA, grayscale ("L") or palette ("P") images. Without
    # this conversion the pixel array would have 4 channels or only 2
    # dimensions, and the transpose below would fail or feed the model the
    # wrong number of channels.
    rgb = image.convert("RGB").resize((32, 32))
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0  # (32, 32, 3), HWC
    # NOTE(review): this assumes the model takes channels-first input, as the
    # original comment states. Keras exports are often channels-last
    # (1, 32, 32, 3); confirm against ort_session.get_inputs()[0].shape.
    chw = pixels.transpose(2, 0, 1)
    # Add the batch axis; make the buffer contiguous for the runtime.
    return np.ascontiguousarray(chw[np.newaxis, ...])
def predict(image):
    """Classify an uploaded image and return per-class confidences.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio image input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict mapping each CIFAR-10 label to a float confidence. The
        ``gr.Label(num_top_classes=3)`` output renders this as the top three
        classes. Returns ``None`` when no image was provided.
    """
    if image is None:
        return None
    input_data = preprocess_image(image)
    # Read the input name from the session instead of hard-coding the
    # exporter-generated name ("serving_default_keras_tensor:0"), so a
    # re-exported model keeps working.
    input_name = ort_session.get_inputs()[0].name
    outputs = ort_session.run(None, {input_name: input_data})[0]
    scores = np.asarray(outputs, dtype=np.float64).ravel()
    # The model may end in a softmax (probabilities) or emit raw logits. Only
    # normalise when the scores are not already a probability distribution.
    # Either way the ranking (argmax) is unchanged.
    if scores.min() < 0 or not np.isclose(scores.sum(), 1.0, atol=1e-3):
        exp = np.exp(scores - scores.max())  # shift by max for numerical stability
        scores = exp / exp.sum()
    return {label: float(score) for label, score in zip(labels, scores)}
# Gradio UI: one PIL image in, the CIFAR-10 Label component (top 3) out.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="CIFAR-10 Classifier",
    description="Upload an image to classify it into CIFAR-10 categories.",
    allow_flagging="never",  # no Flag button in the UI
)

# Listen on all network interfaces at port 7860 and request a share link.
# NOTE(review): on Hugging Face Spaces Gradio is expected to ignore
# share=True; when run locally it opens a public tunnel. Confirm this is
# intended.
interface.launch(server_name="0.0.0.0", server_port=7860, share=True)