File size: 2,412 Bytes
37aa04f
bf220f0
 
c0227d4
bf220f0
c0227d4
 
 
bf220f0
 
c0227d4
 
 
 
 
 
 
 
 
bf220f0
 
831c84d
bf220f0
c0227d4
 
 
 
 
 
 
 
 
 
831c84d
bf220f0
 
c0227d4
 
 
bf220f0
 
 
c0227d4
bf220f0
37aa04f
831c84d
c0227d4
 
 
 
 
bf220f0
c0227d4
 
 
 
bf220f0
 
 
c0227d4
bf220f0
37aa04f
bf220f0
c0227d4
 
bf220f0
831c84d
c0227d4
831c84d
c0227d4
bf220f0
 
 
 
 
 
 
c0227d4
bf220f0
 
 
 
 
c0227d4
 
 
bf220f0
831c84d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
import numpy as np
from PIL import Image
import onnxruntime as ort

# === MODEL SETTINGS ===
# Relative path to the exported ONNX model; it is resolved against the working
# directory the app is started from.
MODEL_PATH = "cnn_largefish_model.onnx"
# Side length in pixels: preprocess_image() resizes every upload to a square
# IMG_SIZE x IMG_SIZE image before inference.
IMG_SIZE = 64

# Human-readable label for each model output position: predict() takes the
# argmax over the model's output row and uses that index into this list.
# NOTE(review): this order is not alphabetical, so it must match the exact label
# order used when the model was trained — confirm against the training pipeline
# before reordering or adding entries.
CLASS_NAMES = [
    "House Mackerel",
    "Black Sea Sprat",
    "Sea Bass",
    "Red Mullet",
    "Trout",
    "Striped Red Mullet",
    "Shrimp",
    "Gilt-Head Bream",
    "Red Sea Bream",
]


@st.cache_resource
def load_session():
    """Open the ONNX model at MODEL_PATH in an ONNX Runtime session.

    Decorated with ``st.cache_resource`` so the model is loaded once and the
    same session is reused instead of being recreated on every call.

    Returns:
        A ``(session, input_name)`` tuple: the CPU-only
        ``onnxruntime.InferenceSession`` and the name of the model's first
        input, which callers use as the feed-dict key for ``session.run``.
    """
    cpu_only = ["CPUExecutionProvider"]
    sess = ort.InferenceSession(MODEL_PATH, providers=cpu_only)
    first_input = sess.get_inputs()[0]
    return sess, first_input.name


def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a model-ready batch containing one image.

    The image is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``
    pixels, and its 0-255 pixel values are scaled into the 0-1 range.

    Args:
        image: Any PIL image; its mode is normalised to RGB here.

    Returns:
        A float32 array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``.
    """
    rgb = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(rgb).astype("float32")
    scaled = pixels / 255.0
    # Prepend a batch axis: (H, W, 3) -> (1, H, W, 3).
    return scaled[np.newaxis, ...]


def predict(image: Image.Image):
    """Classify a single image with the cached ONNX Runtime session.

    Args:
        image: The PIL image to classify.

    Returns:
        ``(pred_class, pred_conf, preds)``:
        ``preds`` is the model's first output row for this image (expected to
        hold one score per entry of ``CLASS_NAMES``); ``pred_class`` is the
        class name at the highest-scoring index; ``pred_conf`` is that score
        as a Python float.

    NOTE(review): ``pred_conf`` is only a probability if the model's last
    layer is a softmax — confirm against the exported model.
    """
    session, input_name = load_session()
    batch = preprocess_image(image)

    # session.run returns a list of output tensors: take the first output,
    # then the first (and only) row of the batch.
    outputs = session.run(None, {input_name: batch})
    scores = outputs[0][0]

    best = int(np.argmax(scores))
    return CLASS_NAMES[best], float(scores[best]), scores


# === STREAMLIT UI ===
st.set_page_config(page_title="Large-Scale Fish Classifier", page_icon="🐟")

st.title("🐟 Large-Scale Fish Classifier")
st.write("Upload een afbeelding van een vis en het model voorspelt de soort.")

uploaded_file = st.file_uploader("Upload een afbeelding", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # The `type` filter above only restricts file extensions, so the uploaded
    # bytes may still be corrupt or not an image at all. Image.open() is lazy,
    # so load() forces the full decode here. Bad files are then reported to the
    # user instead of raising later inside st.image() or predict().
    try:
        image = Image.open(uploaded_file)
        image.load()
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL.UnidentifiedImageError and truncated files;
        # DecompressionBombError is raised for absurdly large pixel counts.
        st.error(
            "Dit bestand kon niet als afbeelding worden gelezen. "
            "Upload een geldige jpg/jpeg/png."
        )
    else:
        # NOTE(review): newer Streamlit releases deprecate `use_column_width`
        # in favour of `use_container_width`; kept for compatibility with the
        # installed version — confirm the pinned version before switching.
        st.image(image, caption="Geüploade afbeelding", use_column_width=True)

        if st.button("Classify"):
            with st.spinner("Bezig met voorspellen..."):
                pred_class, pred_conf, preds = predict(image)

            st.subheader("Voorspelling")
            st.write(f"**{pred_class}** met **{pred_conf:.2%}** zekerheid.")

            st.subheader("Class probabilities")
            # Pair each class name with its score; float() turns the numpy
            # scalars into plain Python floats for the chart.
            st.bar_chart(
                {name: float(score) for name, score in zip(CLASS_NAMES, preds)}
            )
else:
    st.info("➡️ Upload eerst een afbeelding (jpg/jpeg/png).")