# DrNikDJ's picture
# Update app.py
# a1f6062 verified
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
# 1. Load the trained Keras classifier once, at import time, so every request
# reuses the same instance. The path is relative to the working directory, so
# 'my_cifar_model.keras' must be uploaded alongside this file in the Space.
model = tf.keras.models.load_model(filepath='my_cifar_model.keras')
# 2. Human-readable class names. Position i names output unit i of the model,
# so this list must stay in the standard CIFAR-10 class order.
labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
def predict(img):
    """
    Classify a single image into one of the CIFAR-10 categories.

    The image is converted to RGB, resized to the 32x32 resolution the model
    expects, scaled to [0, 1], and run through the module-level ``model``.
    The result maps every entry of ``labels`` to a probability, which
    ``gr.Label`` uses to show the top classes.

    Args:
        img: The uploaded image. ``gr.Image()`` delivers a NumPy array by
            default; an existing ``PIL.Image.Image`` is accepted as well.
            ``None`` when the input has been cleared.

    Returns:
        ``{label: probability}`` for every class, or ``None`` if no image
        was provided.

    Raises:
        ValueError: If the number of model outputs differs from
            ``len(labels)``.
    """
    if img is None:
        return None

    # Accept a NumPy array (Gradio's default) or an already-built PIL image;
    # Image.fromarray would fail on the latter.
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
    # Force exactly 3 channels: grayscale or RGBA uploads would otherwise
    # yield (32, 32) or (32, 32, 4) arrays that the model cannot consume.
    pil_img = pil_img.convert("RGB").resize((32, 32))

    # Scale pixels to [0, 1].
    # NOTE(review): assumes the model was trained on inputs normalized this
    # way — confirm against the training script.
    img_array = np.asarray(pil_img, dtype=np.float32) / 255.0
    # Add batch dimension: (32, 32, 3) -> (1, 32, 32, 3)
    img_array = np.expand_dims(img_array, axis=0)

    # verbose=0 keeps Keras from printing a progress bar on every request.
    outputs = model.predict(img_array, verbose=0).flatten()
    if outputs.shape[0] != len(labels):
        raise ValueError(
            f"Model returned {outputs.shape[0]} scores but "
            f"{len(labels)} labels are defined."
        )

    # Apply softmax only when the model emits logits. If its final layer
    # already applies softmax, a second softmax would flatten the
    # distribution and understate the model's confidence.
    is_distribution = bool(
        np.all(outputs >= 0.0) and np.isclose(outputs.sum(), 1.0, atol=1e-3)
    )
    if is_distribution:
        scores = outputs
    else:
        # Subtract the max before exponentiating for numerical stability.
        exp = np.exp(outputs - outputs.max())
        scores = exp / exp.sum()

    # Pair labels with scores by position instead of a hard-coded range(10).
    return {label: float(score) for label, score in zip(labels, scores)}
# 3. Create the Gradio Interface.
# type="numpy" is Gradio's default for gr.Image; it is spelled out because
# predict() is written to receive the upload as a NumPy array.
# The Label output shows only the three most probable classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),
    title="CIFAR-10 Image Classifier",
    description="Upload an image and the model will predict its category among 10 classes.",
)
# 4. Start the Gradio server, but only when this file is executed directly
# (e.g. `python app.py`); merely importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()