|
|
import os |
|
|
import numpy as np |
|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
from tensorflow.keras.models import load_model |
|
|
from tensorflow.keras.preprocessing.image import img_to_array |
|
|
|
|
|
|
|
|
# Keras model file; a relative path, so it is looked up in the working directory.
MODEL_PATH = "custom_cnn_last4_finetuned.h5"

# Target image size. prepare_image treats it as (height, width) because PIL's
# resize() takes (width, height).
IMG_SIZE = (256, 256)

# Human-readable labels. predict_top1 indexes this list with the argmax of the
# model output, so the order must match the model's output units.
# NOTE(review): "hourse_mackerel" looks misspelled but presumably mirrors the
# training dataset's label — confirm before changing it.
CLASS_NAMES = ["animal fish", "animal fish bass"] + [
    "fish sea_food " + species
    for species in (
        "black_sea_sprat",
        "gilt_head_bream",
        "hourse_mackerel",
        "red_mullet",
        "red_sea_bream",
        "sea_bass",
        "shrimp",
        "striped_red_mullet",
        "trout",
    )
]
|
|
|
|
|
# Configure the page before any other st.* call (older Streamlit versions
# require set_page_config to be the first Streamlit command).
st.set_page_config(page_title="Custom CNN Fish Classifier", layout="centered")

st.title("🐟 Fish Classifier")
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_model_cached():
    """Load the Keras model once and reuse it across Streamlit reruns.

    Any exception from ``load_model`` propagates. Streamlit does not store
    the result of a cached call that raises, so a failed load is retried on
    the next rerun (for example after the model file has been uploaded).
    Without this, a cached ``None`` would keep the app broken until restart.
    """
    # compile=False: the model is only used for inference, so the training
    # configuration (optimizer, loss, metrics) is not needed.
    return load_model(MODEL_PATH, compile=False)


def load_cnn_model():
    """Return the loaded Keras model, or ``None`` if it cannot be loaded.

    On failure, an error message and upload instructions are rendered in the
    app. The failure is not cached, so subsequent reruns attempt the load again.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Model loading failed:\n{e}")
        st.info(
            "**Upload your model file to this Space:**\n\n"
            f"File must be named: `{MODEL_PATH}`"
        )
        return None
|
|
|
|
|
model = load_cnn_model()

if model is None:
    # Help diagnose a missing or misnamed model file by listing what is
    # actually in the working directory, then halt this script run.
    # Names are sorted for a stable display and wrapped in backticks so
    # characters such as '*' or '_' are shown literally, not as markdown.
    st.write("Files in this space:")
    st.markdown("\n".join(f"- `{name}`" for name in sorted(os.listdir("."))))
    st.stop()
|
|
|
|
|
|
|
|
def prepare_image(pil_img):
    """Convert a PIL image into a one-image batch scaled to [0, 1].

    The image is forced to 3-channel RGB and resized to ``IMG_SIZE``, which is
    read as (height, width); PIL's ``resize`` expects (width, height), hence
    the swap.

    Returns:
        A float array of shape ``(1, height, width, 3)``, assuming Keras'
        default ``channels_last`` data format.
    """
    height, width = IMG_SIZE
    rgb = pil_img.convert("RGB").resize((width, height))
    # Prepend the batch axis; the division then maps 0..255 pixels to [0, 1].
    return img_to_array(rgb)[np.newaxis, ...] / 255.0
|
|
|
|
|
|
|
|
def predict_top1(img):
    """Classify ``img`` and return the top label with the model's score for it.

    Args:
        img: A PIL image; it is preprocessed with ``prepare_image``.

    Returns:
        ``(label, score)``, where ``label`` is an entry of ``CLASS_NAMES`` and
        ``score`` is the model output at that index as a Python float.

    Raises:
        ValueError: if the number of model outputs differs from
            ``len(CLASS_NAMES)``, i.e. labels cannot be matched to outputs.
    """
    batch = prepare_image(img)
    scores = model.predict(batch, verbose=0)[0]
    # CLASS_NAMES is hard-coded separately from the (user-supplied) model file,
    # so check they agree instead of silently mislabelling or raising a bare
    # IndexError when argmax falls past the end of the list.
    if len(scores) != len(CLASS_NAMES):
        raise ValueError(
            f"Model produces {len(scores)} outputs but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} entries."
        )
    best = int(np.argmax(scores))
    # NOTE(review): the UI shows this as a percentage confidence, which
    # assumes the model ends in a softmax layer — confirm.
    return CLASS_NAMES[best], float(scores[best])
|
|
|
|
|
|
|
|
uploaded = st.file_uploader("Upload fish image", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    # The upload is untrusted. A corrupt or non-image file raises OSError
    # (PIL's UnidentifiedImageError subclasses it), and an oversized image
    # raises DecompressionBombError. img.load() forces a full decode here, so
    # those errors are reported to the user rather than surfacing as a
    # traceback later inside st.image or predict_top1.
    try:
        img = Image.open(uploaded)
        img.load()
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded image:\n{e}")
        st.stop()

    st.image(img, caption="Uploaded Image", use_container_width=True)

    if st.button("Predict"):
        label, prob = predict_top1(img)
        st.markdown(f"## 🎯 Prediction: **{label}**")
        st.markdown(f"### Confidence: **{prob*100:.2f}%**")