Spaces:
Sleeping
Sleeping
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort  # For ONNX inference
from PIL import Image
# Load the ONNX model once, at import time, so every request reuses one session.
# The path is resolved next to this script instead of against the current working
# directory, so launching the app from elsewhere (e.g. ``python path/to/app.py``)
# still finds the model file. Kept as ``str`` for any code that expects a string path.
model_path = str(Path(__file__).resolve().with_name("cifar10_model.onnx"))
ort_session = ort.InferenceSession(model_path)
# Human-readable CIFAR-10 class names. Position i is the name reported for model
# output index i (assumes the model was trained with the standard CIFAR-10 ordering).
labels = "airplane automobile bird cat deer dog frog horse ship truck".split()
def preprocess_image(image):
    """Turn a PIL image into a float32 batch of one for the ONNX model.

    Args:
        image: A ``PIL.Image.Image`` in any mode (RGB, RGBA, grayscale "L",
            palette "P", ...).

    Returns:
        ``np.ndarray`` of shape ``(1, 3, 32, 32)`` (batch, channels, height,
        width), dtype float32, with pixel values scaled into ``[0, 1]``.
    """
    # Force three RGB channels first: a grayscale image would otherwise become a
    # 2-D array (making the 3-axis transpose below raise), and RGBA / palette
    # images would carry the wrong number of channels.
    image = image.convert("RGB").resize((32, 32))
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # (32, 32, 3), channels-last
    # NOTE(review): models exported from Keras are often channels-last (NHWC);
    # confirm in Netron that the ONNX input really is (1, 3, 32, 32) before
    # relying on this HWC -> CHW transpose.
    return np.expand_dims(pixels.transpose(2, 0, 1), axis=0)
| def predict(image): | |
| # Preprocess the image | |
| input_data = preprocess_image(image) | |
| # Run inference (use the correct input name from Netron) | |
| outputs = ort_session.run(None, {"serving_default_keras_tensor:0": input_data})[0] | |
| predicted_class_idx = np.argmax(outputs) | |
| return labels[predicted_class_idx] | |
# Build the Gradio UI. ``gr.Label(num_top_classes=3)`` ranks and shows the three
# highest-confidence classes when ``predict`` returns a {label: confidence}
# mapping; a plain string result is displayed as a single label.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="CIFAR-10 Classifier",
    description="Upload an image to classify it into one of the CIFAR-10 classes.",
)

# Launch only when run as a script, so importing this module (tests, Gradio's
# reload mode, which looks up ``demo``) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()  # Add share=True for a public link