# app.py — Gradio demo serving a custom CIFAR-10 CNN image classifier.
# Author: aditya-sah ("Create app.py", commit 127f829).
import tensorflow as tf
import numpy as np
import gradio as gr
from PIL import Image
# Deserialize the trained network once, at import time, so every call to
# predict() below reuses the same in-memory model.
model = tf.keras.models.load_model(filepath="cifar10_custom_cnn.keras")
# Display names for the CIFAR-10 categories. predict() looks up the argmax of
# the model output in this list, so the order must match the label order the
# model was trained with.
class_names = [
    "Airplane",    # 0
    "Automobile",  # 1
    "Bird",        # 2
    "Cat",         # 3
    "Deer",        # 4
    "Dog",         # 5
    "Frog",        # 6
    "Horse",       # 7
    "Ship",        # 8
    "Truck",       # 9
]
def predict(image):
    """Classify a single image into one of the CIFAR-10 classes.

    The image is converted to 3-channel RGB, resized to 32x32, scaled to
    [0, 1], batched as a (1, 32, 32, 3) array and passed to the module-level
    ``model``. The highest-scoring output index is mapped to a name through
    the module-level ``class_names`` list.

    Args:
        image: A PIL image (as supplied by ``gr.Image(type="pil")``), or
            ``None`` when the form is submitted without an upload.

    Returns:
        The predicted class name (a ``str`` from ``class_names``), or ``None``
        when no image was provided, which leaves the Gradio label empty.
    """
    if image is None:
        # Gradio passes None when the user clicks submit with no image.
        return None

    # Force exactly 3 channels: RGBA (PNG with alpha), grayscale ("L") or
    # palette ("P") images would otherwise yield an array that cannot be
    # shaped into the (1, 32, 32, 3) batch below.
    rgb = image.convert("RGB").resize((32, 32))

    # NOTE(review): assumes the model was trained on inputs scaled to [0, 1]
    # (same preprocessing as before) — confirm against the training script.
    batch = (np.asarray(rgb, dtype=np.float32) / 255.0)[np.newaxis, ...]

    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = model.predict(batch, verbose=0)

    # Take the argmax of the single row rather than of the flattened batch.
    class_index = int(np.argmax(predictions[0]))
    return class_names[class_index]
# Gradio UI: one PIL image in, the predicted class name shown as a label out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="CIFAR-10 Image Classification",
    description="Custom CNN model trained on CIFAR-10 dataset",
)

# Start the (blocking) web server only when run as a script, e.g.
# `python app.py`. Importing this module, for tests or tooling, then builds
# the interface without launching it.
if __name__ == "__main__":
    interface.launch()