import streamlit as st import numpy as np from PIL import Image import onnxruntime as ort # === MODEL SETTINGS === MODEL_PATH = "cnn_largefish_model.onnx" IMG_SIZE = 64 CLASS_NAMES = [ "House Mackerel", "Black Sea Sprat", "Sea Bass", "Red Mullet", "Trout", "Striped Red Mullet", "Shrimp", "Gilt-Head Bream", "Red Sea Bream", ] @st.cache_resource def load_session(): """ Laad het ONNX-model één keer in een ONNX Runtime sessie. """ session = ort.InferenceSession( MODEL_PATH, providers=["CPUExecutionProvider"], ) input_name = session.get_inputs()[0].name return session, input_name def preprocess_image(image: Image.Image) -> np.ndarray: """ Resize + normaliseer naar (1, 64, 64, 3) met waarden 0–1. """ image = image.convert("RGB") image = image.resize((IMG_SIZE, IMG_SIZE)) arr = np.array(image).astype("float32") / 255.0 arr = np.expand_dims(arr, axis=0) # (1, 64, 64, 3) return arr def predict(image: Image.Image): """ Run één voorspelling via ONNX Runtime. """ session, input_name = load_session() x = preprocess_image(image) # ONNX Runtime geeft een list terug; [0] is de output tensor preds = session.run(None, {input_name: x})[0][0] # shape: (9,) pred_idx = int(np.argmax(preds)) pred_class = CLASS_NAMES[pred_idx] pred_conf = float(preds[pred_idx]) return pred_class, pred_conf, preds # === STREAMLIT UI === st.set_page_config(page_title="Large-Scale Fish Classifier", page_icon="🐟") st.title("🐟 Large-Scale Fish Classifier") st.write("Upload een afbeelding van een vis en het model voorspelt de soort.") uploaded_file = st.file_uploader("Upload een afbeelding", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: image = Image.open(uploaded_file) st.image(image, caption="Geüploade afbeelding", use_column_width=True) if st.button("Classify"): with st.spinner("Bezig met voorspellen..."): pred_class, pred_conf, preds = predict(image) st.subheader("Voorspelling") st.write(f"**{pred_class}** met **{pred_conf:.2%}** zekerheid.") st.subheader("Class probabilities") st.bar_chart( {CLASS_NAMES[i]: float(preds[i]) for i in range(len(CLASS_NAMES))} ) else: st.info("➡️ Upload eerst een afbeelding (jpg/jpeg/png).")