File size: 2,187 Bytes
0c425a8
 
 
 
ddedb37
 
0c425a8
ddedb37
8b4cd5f
ddedb37
0c425a8
ddedb37
 
 
 
 
 
 
 
 
 
 
 
 
0c425a8
ddedb37
 
0c425a8
ddedb37
 
 
 
 
 
 
 
 
 
4520d1e
ddedb37
 
0c425a8
ddedb37
0c425a8
ddedb37
 
 
 
 
 
0c425a8
 
ddedb37
 
 
 
 
 
 
 
0c425a8
ddedb37
 
 
 
 
 
 
 
 
 
 
 
0c425a8
594e520
 
ddedb37
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array

# ---- Configuration -----------------------------------------------------------
# Keras model file; a bare relative name, so it is looked up in the
# process's current working directory.
MODEL_PATH = "custom_cnn_last4_finetuned.h5"

# Spatial size the network is fed. prepare_image passes IMG_SIZE[1], IMG_SIZE[0]
# to PIL's resize (which takes width, height), so this reads as (height, width).
IMG_SIZE = (256, 256)

# Human-readable label for each model output position. predict_top1 indexes
# this list with the argmax of the prediction vector, so its order must match
# the network's output layer.
# NOTE(review): "hourse_mackerel" is presumably spelled as in the training
# dataset's labels; confirm before "correcting" it.
CLASS_NAMES = [
    "animal fish",
    "animal fish bass",
    "fish sea_food black_sea_sprat",
    "fish sea_food gilt_head_bream",
    "fish sea_food hourse_mackerel",
    "fish sea_food red_mullet",
    "fish sea_food red_sea_bream",
    "fish sea_food sea_bass",
    "fish sea_food shrimp",
    "fish sea_food striped_red_mullet",
    "fish sea_food trout",
]

# Browser-tab metadata and centered layout, then the page heading.
st.set_page_config(
    layout="centered",
    page_title="Custom CNN Fish Classifier",
)
st.title("🐟 Fish Classifier")

# LOAD MODEL
@st.cache_resource
def _load_model_cached(path):
    """Load the Keras model stored at *path*, cached across Streamlit reruns.

    Any loading error propagates to the caller on purpose. Streamlit caches
    the value a cached function *returns*, so returning None here would pin
    the failure in the cache. The app would then keep reporting a missing
    model even after the file is uploaded. An exception is not cached, so the
    next rerun retries the load.

    compile=False: the app only calls ``predict``, so no optimizer/loss
    configuration needs to be restored.
    """
    return load_model(path, compile=False)


def load_cnn_model():
    """Return the classifier model, or None if it cannot be loaded.

    On failure, an error message and an upload hint naming MODEL_PATH are
    rendered in the app instead of raising.
    """
    try:
        return _load_model_cached(MODEL_PATH)
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Model loading failed:\n{e}")
        st.info(f"""
        **Upload your model file to this Space:**
        File must be named: `{MODEL_PATH}`
        """)
        return None


model = load_cnn_model()

if model is None:
    # Diagnostic aid for a missing/misnamed model: list the current working
    # directory, which is where the relative MODEL_PATH is resolved.
    try:
        entries = sorted(os.listdir("."))  # sorted for a stable, readable list
    except OSError as e:
        st.write(f"Could not list files: {e}")
    else:
        st.write("Files in this space:")
        for f in entries:
            st.write(f"- {f}")
    # Nothing below can run without a model; end this script run here.
    st.stop()

# PREPROCESS IMAGE
def prepare_image(pil_img):
    """Convert a PIL image into a one-image batch for the model.

    Steps: force RGB, resize to IMG_SIZE, scale pixel values by 1/255
    (0-255 -> 0-1), then prepend a batch axis of length 1.
    NOTE(review): assumes the model was trained on inputs scaled the same way.
    """
    # IMG_SIZE is (height, width); PIL's resize wants (width, height).
    height, width = IMG_SIZE
    rgb = pil_img.convert("RGB").resize((width, height))

    scaled = img_to_array(rgb) / 255.0
    return np.expand_dims(scaled, axis=0)

# PREDICT
def predict_top1(img):
    x = prepare_image(img)
    preds = model.predict(x, verbose=0)[0]
    top_index = np.argmax(preds)

    return CLASS_NAMES[top_index], float(preds[top_index])

# UI
uploaded = st.file_uploader("Upload fish image", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    try:
        img = Image.open(uploaded)
        # Image.open is lazy. Force a full decode now so that corrupt,
        # truncated or mislabeled uploads (the extension filter above does
        # not inspect content) fail here with a readable message, rather
        # than inside st.image or prediction. UnidentifiedImageError and
        # truncation errors are OSError subclasses.
        img.load()
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded image: {e}")
        st.stop()

    st.image(img, caption="Uploaded Image", use_container_width=True)

    if st.button("Predict"):
        label, prob = predict_top1(img)
        st.markdown(f"## 🎯 Prediction: **{label}**")
        st.markdown(f"### Confidence: **{prob*100:.2f}%**")