Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image | |
| import onnxruntime as ort | |
# === MODEL SETTINGS ===
# ONNX model file opened by load_session(). This is a relative path, so it is
# resolved against the process working directory.
MODEL_PATH = "cnn_largefish_model.onnx"

# Side length (in pixels) of the square RGB image fed to the model.
IMG_SIZE = 64

# Label for each model output position: CLASS_NAMES[i] names the score at
# index i of the model's output vector.
# NOTE(review): this order is not alphabetical and must match the label
# encoding used when the model was trained. Confirm against the training code.
CLASS_NAMES = [
    "House Mackerel",      # 0
    "Black Sea Sprat",     # 1
    "Sea Bass",            # 2
    "Red Mullet",          # 3
    "Trout",               # 4
    "Striped Red Mullet",  # 5
    "Shrimp",              # 6
    "Gilt-Head Bream",     # 7
    "Red Sea Bream",       # 8
]
@st.cache_resource
def load_session():
    """Create the ONNX Runtime session once and reuse it.

    ``predict`` calls this function for every classification, and Streamlit
    re-executes the whole script on every user interaction. Without caching,
    each call would re-read the model file and build a fresh
    ``InferenceSession``. ``st.cache_resource`` keeps a single session alive
    for the server process and shares it between reruns.

    Returns:
        A ``(session, input_name)`` tuple: the CPU-only inference session and
        the name of the model's first input. That name is the feed key
        expected by ``session.run``.
    """
    session = ort.InferenceSession(
        MODEL_PATH,
        providers=["CPUExecutionProvider"],
    )
    # The model is fed through its first declared input only.
    input_name = session.get_inputs()[0].name
    return session, input_name
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a one-image float32 batch for the model.

    The image is forced to 3-channel RGB and resized to ``IMG_SIZE`` x
    ``IMG_SIZE``. Its pixel values are then scaled from 0-255 to 0.0-1.0.

    Returns:
        A float32 array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``, i.e.
        ``(1, 64, 64, 3)`` with the current settings.
    """
    rgb = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.asarray(rgb, dtype=np.float32)  # (IMG_SIZE, IMG_SIZE, 3)
    scaled = pixels / 255.0
    # Prepend a batch axis of length 1.
    return scaled[np.newaxis, ...]
def predict(image: Image.Image):
    """Classify one image with the ONNX model.

    Returns:
        ``(label, confidence, scores)``:
        - ``scores`` is the first model output for the single batch item.
        - ``label`` is the ``CLASS_NAMES`` entry at the highest-scoring index.
        - ``confidence`` is that top score as a Python float.

        NOTE(review): the UI shows ``confidence`` as a percentage, so the
        scores are presumably softmax probabilities. Confirm the model ends
        in a softmax layer.
    """
    session, feed_name = load_session()
    batch = preprocess_image(image)
    # session.run returns one array per model output. Take the first output,
    # then the first (only) row of the batch.
    outputs = session.run(None, {feed_name: batch})
    scores = outputs[0][0]
    best = int(np.argmax(scores))
    return CLASS_NAMES[best], float(scores[best]), scores
# === STREAMLIT UI ===
# Top-level script: Streamlit re-runs everything below on every interaction.
st.set_page_config(page_title="Large-Scale Fish Classifier", page_icon="🐟")
st.title("🐟 Large-Scale Fish Classifier")
st.write("Upload een afbeelding van een vis en het model voorspelt de soort.")

uploaded_file = st.file_uploader("Upload een afbeelding", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open only reads the header. Decode now so that corrupt,
        # truncated or mislabelled uploads fail here with a readable message
        # instead of a traceback later (in st.image or during preprocessing).
        image.load()
    except (OSError, Image.DecompressionBombError):
        # OSError also covers PIL's UnidentifiedImageError (not an image).
        st.error("Kon de afbeelding niet lezen. Upload een geldig jpg/jpeg/png-bestand.")
        st.stop()

    st.image(image, caption="Geüploade afbeelding", use_column_width=True)

    if st.button("Classify"):
        with st.spinner("Bezig met voorspellen..."):
            pred_class, pred_conf, preds = predict(image)

        st.subheader("Voorspelling")
        st.write(f"**{pred_class}** met **{pred_conf:.2%}** zekerheid.")

        st.subheader("Class probabilities")
        # Pair each label with the score at the same output index.
        st.bar_chart(dict(zip(CLASS_NAMES, map(float, preds))))
else:
    st.info("➡️ Upload eerst een afbeelding (jpg/jpeg/png).")