File size: 2,187 Bytes
0c425a8 ddedb37 0c425a8 ddedb37 8b4cd5f ddedb37 0c425a8 ddedb37 0c425a8 ddedb37 0c425a8 ddedb37 4520d1e ddedb37 0c425a8 ddedb37 0c425a8 ddedb37 0c425a8 ddedb37 0c425a8 ddedb37 0c425a8 594e520 ddedb37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
# CONFIG
# Keras model file; a relative path, so it is resolved against the working directory.
MODEL_PATH = "custom_cnn_last4_finetuned.h5"

# Model input size as (height, width); prepare_image hands it to PIL's
# resize swapped, i.e. as (width, height).
IMG_SIZE = (256, 256)

# Output labels. Position i names the model's i-th output unit: predict_top1
# indexes this list with the argmax of the prediction vector.
# NOTE(review): "hourse_mackerel" is kept verbatim - it presumably mirrors the
# training label names; do not correct the spelling without checking that.
CLASS_NAMES = [
    "animal fish",
    "animal fish bass",
    "fish sea_food black_sea_sprat",
    "fish sea_food gilt_head_bream",
    "fish sea_food hourse_mackerel",
    "fish sea_food red_mullet",
    "fish sea_food red_sea_bream",
    "fish sea_food sea_bass",
    "fish sea_food shrimp",
    "fish sea_food striped_red_mullet",
    "fish sea_food trout",
]
# Browser-tab title and centered layout, then the heading shown on the page.
st.set_page_config(layout="centered", page_title="Custom CNN Fish Classifier")
st.title("🐟 Fish Classifier")
# LOAD MODEL
@st.cache_resource
def _load_model_from_disk(path):
    """Load the Keras model stored at *path*, cached across script reruns.

    Load errors propagate out of this function, so a failed attempt is not
    stored in the cache and the next rerun tries the file again.
    """
    # compile=False: this app only calls model.predict, so no training
    # configuration (optimizer/loss) needs to be restored.
    return load_model(path, compile=False)


def load_cnn_model():
    """Return the classifier model, or None if it could not be loaded.

    On failure, an error message and an upload hint naming MODEL_PATH are
    rendered on the page instead of raising, so the caller decides what to do.
    """
    try:
        return _load_model_from_disk(MODEL_PATH)
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Model loading failed:\n{e}")
        # Blank line between the two sentences so markdown renders them on
        # separate lines; the filename comes from MODEL_PATH to stay in sync.
        st.info(
            "**Upload your model file to this Space:**\n\n"
            f"File must be named: `{MODEL_PATH}`"
        )
        return None
model = load_cnn_model()
if model is None:
    # Help diagnose a missing or misnamed model file: list the working
    # directory, which is where the relative MODEL_PATH is looked up.
    # Sorted because os.listdir returns entries in arbitrary order.
    st.write("Files in this space:")
    for name in sorted(os.listdir(".")):
        st.write(f"- {name}")
    # Nothing below can run without a model.
    st.stop()
# PREPROCESS IMAGE
def prepare_image(pil_img):
    """Convert a PIL image into a one-image input batch for the model.

    Steps: force RGB, resize to IMG_SIZE, convert with ``img_to_array``, scale
    pixel values by 1/255, and prepend a batch axis of length 1.

    IMG_SIZE is (height, width) while PIL's ``resize`` takes (width, height),
    so the two entries are swapped when resizing.
    """
    width, height = IMG_SIZE[1], IMG_SIZE[0]
    rgb_resized = pil_img.convert("RGB").resize((width, height))
    # NOTE(review): the 0-1 scaling presumably matches the training
    # preprocessing - confirm against the training pipeline.
    normalized = img_to_array(rgb_resized) / 255.0
    return normalized[np.newaxis, ...]
# PREDICT
def predict_top1(img):
    """Classify *img* and return its single highest-scoring class.

    Returns:
        (label, score): the CLASS_NAMES entry at the argmax of the model's
        output for this image, and that output value as a Python float.

    Raises:
        ValueError: if the model's output width differs from
            len(CLASS_NAMES). Without this check a wider output raises an
            opaque IndexError and a narrower one maps to the wrong labels.
    """
    batch = prepare_image(img)
    scores = model.predict(batch, verbose=0)[0]
    if len(scores) != len(CLASS_NAMES):
        raise ValueError(
            f"Model outputs {len(scores)} scores but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} labels; they must match one-to-one."
        )
    top_index = int(np.argmax(scores))
    # NOTE(review): the UI shows this score as a percentage "confidence",
    # which assumes the model's last layer is a softmax - confirm.
    return CLASS_NAMES[top_index], float(scores[top_index])
# UI
uploaded = st.file_uploader("Upload fish image", type=["jpg", "jpeg", "png"])
if uploaded is not None:
    try:
        img = Image.open(uploaded)
        # Image.open is lazy: force a full decode here so corrupt, truncated
        # or mislabeled uploads fail inside this try instead of later in
        # st.image / prepare_image.
        img.load()
    except (OSError, Image.DecompressionBombError) as e:
        # PIL's UnidentifiedImageError and truncated-file errors are OSError
        # subclasses; DecompressionBombError guards against huge images.
        st.error(f"Could not read the uploaded file as an image:\n{e}")
        st.stop()
    st.image(img, caption="Uploaded Image", use_container_width=True)
    if st.button("Predict"):
        label, prob = predict_top1(img)
        st.markdown(f"## 🎯 Prediction: **{label}**")
        st.markdown(f"### Confidence: **{prob*100:.2f}%**")