# Thrishul3549x's picture
# Create app.py
# 1438409 verified
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.preprocessing import image
# Load CIFAR-10 dataset and scale pixel values from [0, 255] to [0, 1].
# Cast to float32 before dividing: dividing the uint8 arrays by 255.0
# directly promotes them to float64, doubling memory (~1.2 GB for x_train
# alone) even though Keras' default float type is float32.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# CIFAR-10 class names; position i is the label for model output unit i.
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
# Build model: three 3x3 conv layers (the first two each followed by 2x2
# max-pooling), then a 64-unit dense layer and a 10-way softmax, one output
# unit per CIFAR-10 class.
# The input shape is declared with an explicit Input layer instead of the
# `input_shape=` argument on Conv2D, which Keras 3 deprecates with a warning.
model = models.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
# Compile and train briefly. One epoch keeps start-up fast but gives only
# modest accuracy; set the CIFAR10_EPOCHS environment variable to train for
# more epochs without editing the code.
EPOCHS = int(os.environ.get("CIFAR10_EPOCHS", "1"))
# Sparse cross-entropy: the labels from load_data() are integer class
# indices rather than one-hot vectors.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# The test split doubles as validation data so the training log reports
# held-out accuracy after each epoch.
model.fit(x_train, y_train, epochs=EPOCHS, validation_data=(x_test, y_test))
# Prediction function
def predict(img):
    """Classify an uploaded image into the CIFAR-10 classes.

    Args:
        img: Image array from the Gradio ``Image(type="numpy")`` input with
            raw 0-255 pixel values, either H x W (grayscale) or
            H x W x C with 3 (RGB) or 4 (RGBA) channels. ``None`` when
            nothing has been uploaded.

    Returns:
        A dict mapping each name in ``class_names`` to the model's
        probability for that class (the format ``gr.Label`` displays),
        or ``None`` when ``img`` is ``None`` so the label is cleared.
    """
    if img is None:
        # Submitted with no image; nothing to classify.
        return None
    img = np.asarray(img)
    if img.ndim == 2:
        # Grayscale: replicate the single channel to match the RGB model input.
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 4:
        # RGBA: drop the alpha channel; the model takes 3-channel input.
        img = img[..., :3]
    img = image.smart_resize(img, (32, 32))  # Resize to CIFAR-10 size
    # Scale to [0, 1] the same way the training data was, then add the
    # leading batch axis the model expects.
    img_array = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)
    # verbose=0 stops Keras printing a progress bar on every request.
    prediction = model.predict(img_array, verbose=0)
    return {name: float(prob) for name, prob in zip(class_names, prediction[0])}
# Gradio interface: an uploaded image goes in, and the top-3 class
# probabilities produced by `predict` come out.
_image_input = gr.Image(type="numpy")
_label_output = gr.Label(num_top_classes=3)
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="CIFAR-10 Image Classifier",
    description=(
        "Upload an image and the model will predict which CIFAR-10 class "
        "it belongs to."
    ),
)
demo.launch()