# Page-scrape header (kept as comments so the file parses):
# prasanthr0416's picture
# Update app.py
# 4520d1e verified
import os
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
# CONFIG
# Keras model file; a relative path, so it is resolved against the current
# working directory when the app runs.
MODEL_PATH = "custom_cnn_last4_finetuned.h5"

# Network input size as (height, width); prepare_image swaps the order for
# PIL's (width, height) resize.
IMG_SIZE = (256, 256)

# Labels indexed by model output unit: CLASS_NAMES[i] names output i.
# NOTE(review): "hourse_mackerel" is presumably the training dataset's own
# folder spelling — do not "fix" it without checking the training labels.
CLASS_NAMES = ["animal fish", "animal fish bass"] + [
    f"fish sea_food {species}"
    for species in (
        "black_sea_sprat",
        "gilt_head_bream",
        "hourse_mackerel",
        "red_mullet",
        "red_sea_bream",
        "sea_bass",
        "shrimp",
        "striped_red_mullet",
        "trout",
    )
]
# Page chrome: browser-tab title and centered layout, then the visible heading.
st.set_page_config(layout="centered", page_title="Custom CNN Fish Classifier")
st.title("🐟 Fish Classifier")
# LOAD MODEL
@st.cache_resource
def _load_cnn_model_cached():
    """Load the Keras model at MODEL_PATH once and share it across reruns.

    Raises whatever ``load_model`` raises. This is deliberate: Streamlit only
    caches return values, so a call that raises is not cached and the next
    rerun retries the load instead of reusing a stale failure.
    """
    return load_model(MODEL_PATH, compile=False)


def load_cnn_model():
    """Return the cached Keras model, or None if it could not be loaded.

    On failure, shows the error plus upload instructions in the app. The
    failure is not cached (only ``_load_cnn_model_cached`` is), so replacing
    the model file takes effect on the next rerun.

    Returns:
        The loaded model, or None when loading raised.
    """
    try:
        return _load_cnn_model_cached()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Model loading failed:\n{e}")
        st.info(
            "**Upload your model file to this Space:**\n\n"
            f"File must be named: `{MODEL_PATH}`"
        )
        return None
model = load_cnn_model()
if model is None:
    # Help diagnose a missing or misnamed model file: list what is actually
    # present in the directory MODEL_PATH points into ("." for a bare name),
    # then halt this script run since nothing below works without a model.
    model_dir = os.path.dirname(MODEL_PATH) or "."
    if os.path.isdir(model_dir):
        st.write("Files in this space:")
        st.markdown(
            "\n".join(f"- `{name}`" for name in sorted(os.listdir(model_dir)))
        )
    st.stop()
# PREPROCESS IMAGE
def prepare_image(pil_img):
    """Turn a PIL image into a one-image batch for the CNN.

    The image is forced to 3-channel RGB, resized to IMG_SIZE, converted with
    Keras ``img_to_array`` and scaled from the 0-255 range to 0-1.

    Args:
        pil_img: PIL image in any mode.

    Returns:
        NumPy array with a leading batch axis of length 1.
    """
    # IMG_SIZE is (height, width); PIL's resize expects (width, height).
    target_h, target_w = IMG_SIZE
    rgb = pil_img.convert("RGB").resize((target_w, target_h))
    pixels = img_to_array(rgb)
    return np.expand_dims(pixels / 255.0, axis=0)
# PREDICT
def predict_top1(img):
    """Classify a PIL image and return its single most likely label.

    Args:
        img: PIL image in any mode/size; preprocessed by ``prepare_image``.

    Returns:
        (label, score): the CLASS_NAMES entry for the highest-scoring output
        unit, and that unit's raw output as a Python float. The UI shows the
        score as a confidence percentage, which presumes the model ends in a
        softmax — NOTE(review): confirm against the training code.

    Raises:
        ValueError: if the model's output width differs from len(CLASS_NAMES),
            which would otherwise mislabel predictions or raise IndexError.
    """
    x = prepare_image(img)
    preds = model.predict(x, verbose=0)[0]  # drop the batch axis of size 1
    if len(preds) != len(CLASS_NAMES):
        raise ValueError(
            f"Model outputs {len(preds)} scores but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} labels; they must match."
        )
    top_index = int(np.argmax(preds))
    return CLASS_NAMES[top_index], float(preds[top_index])
# UI
uploaded = st.file_uploader("Upload fish image", type=["jpg", "jpeg", "png"])
if uploaded is not None:
    try:
        img = Image.open(uploaded)
        # Image.open is lazy; force a full decode here so corrupt, truncated or
        # mis-typed uploads (the type filter only checks the file extension)
        # fail with a friendly message instead of a traceback later.
        img.load()
    except (OSError, Image.DecompressionBombError) as err:
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        st.image(img, caption="Uploaded Image", use_container_width=True)
        if st.button("Predict"):
            label, prob = predict_top1(img)
            st.markdown(f"## 🎯 Prediction: **{label}**")
            st.markdown(f"### Confidence: **{prob*100:.2f}%**")