# CNN / app.py
# Author: Pant0x
# Last change: "Update app.py" (commit 42cbca5, verified)
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import json
# Configure the browser tab (title, favicon) and the page layout, then render the header.
# NOTE: the emoji literals below must stay UTF-8; a previous copy of this file had them
# garbled into mojibake (e.g. "πŸš€" instead of the rocket emoji).
st.set_page_config(
    page_title="CIFAR-10 Classifier",
    page_icon="🖼️",
    layout="centered",
)
st.title("🚀 CIFAR-10 Image Classifier")
st.markdown("Upload an image and see what the model predicts!")
@st.cache_resource
def load_model_and_labels(
    model_path="models/cifar10_cnn.keras",
    labels_path="models/labels_map.json",
):
    """Load the trained Keras model and its class-index -> label-name mapping.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns rather than reloaded on every interaction.

    Args:
        model_path: Path to the saved Keras model file.
        labels_path: Path to a JSON file mapping class indices to label names.
            JSON object keys are strings, so callers look labels up with
            ``labels[str(index)]``.

    Returns:
        A ``(model, labels)`` tuple: the loaded Keras model and the decoded
        JSON mapping.
    """
    model = tf.keras.models.load_model(model_path)
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(labels_path, "r", encoding="utf-8") as f:
        labels = json.load(f)
    return model, labels
def preprocess_image(img):
    """Resize a PIL image to the model's 32x32 input and scale pixels to [0, 1].

    Args:
        img: PIL image; the caller converts it to RGB first, so the result has
            three channels.

    Returns:
        A float32 NumPy array of shape (32, 32, 3) with values in [0, 1].
    """
    img = img.resize((32, 32))
    return np.asarray(img, dtype=np.float32) / 255.0


model, labels = load_model_and_labels()

uploaded_file = st.file_uploader("Upload an image (PNG/JPG)", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    try:
        # convert() forces a full decode, so truncated files are caught here too.
        img = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL.UnidentifiedImageError subclasses OSError; raised for non-image/corrupt uploads.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # Default width: natural size, capped at the column width (the deprecated
    # use_column_width=False let large photos overflow the layout).
    st.image(img, caption="Uploaded Image")

    x = preprocess_image(img)
    with st.spinner("Predicting..."):
        x_input = x.reshape(1, 32, 32, 3)  # add the batch dimension the model expects
        # verbose=0 keeps Keras from printing a progress bar to the server log per request.
        preds = model.predict(x_input, verbose=0)[0]

    # NOTE(review): the percentages below assume the model's last layer is a softmax
    # (outputs sum to 1) — confirm; raw logits would make them meaningless.
    top_idx = preds.argsort()[-3:][::-1]  # indices of the 3 highest scores, best first

    st.subheader("✅ Prediction")
    st.write(f"**Top-1:** {labels[str(top_idx[0])]} ({preds[top_idx[0]] * 100:.2f}%)")

    st.subheader("📊 Top-3 Predictions")
    for i in top_idx:
        st.write(f"{labels[str(i)]}: {preds[i] * 100:.2f}%")

    st.subheader("📈 All Class Probabilities")
    st.bar_chart({labels[str(i)]: float(preds[i]) for i in range(len(labels))})
else:
    st.info("Upload an image to see predictions.")