Spaces:
Sleeping
Sleeping
File size: 2,412 Bytes
37aa04f bf220f0 c0227d4 bf220f0 c0227d4 bf220f0 c0227d4 bf220f0 831c84d bf220f0 c0227d4 831c84d bf220f0 c0227d4 bf220f0 c0227d4 bf220f0 37aa04f 831c84d c0227d4 bf220f0 c0227d4 bf220f0 c0227d4 bf220f0 37aa04f bf220f0 c0227d4 bf220f0 831c84d c0227d4 831c84d c0227d4 bf220f0 c0227d4 bf220f0 c0227d4 bf220f0 831c84d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import streamlit as st
import numpy as np
from PIL import Image
import onnxruntime as ort
# === MODEL SETTINGS ===
from pathlib import Path

# Resolve the model file next to this script instead of relative to the
# current working directory, so the app also works when launched from a
# different folder (e.g. `streamlit run some/dir/app.py`).
MODEL_PATH = str(Path(__file__).resolve().with_name("cnn_largefish_model.onnx"))

# Height and width (in pixels) that uploaded images are resized to before
# inference; preprocess_image produces arrays of shape (1, IMG_SIZE, IMG_SIZE, 3).
IMG_SIZE = 64

# Human-readable label for each model output: index i names score i.
# NOTE(review): this order is not alphabetical, so confirm it matches the
# label encoding used when the model was trained.
CLASS_NAMES = [
    "House Mackerel",
    "Black Sea Sprat",
    "Sea Bass",
    "Red Mullet",
    "Trout",
    "Striped Red Mullet",
    "Shrimp",
    "Gilt-Head Bream",
    "Red Sea Bream",
]
@st.cache_resource
def load_session():
    """Build the ONNX Runtime session for ``MODEL_PATH`` and report its input name.

    The ``st.cache_resource`` decorator makes Streamlit reuse the returned
    objects on later reruns, so the model file is loaded only once.

    Returns:
        tuple: ``(session, input_name)``. ``session`` is an
        ``ort.InferenceSession`` restricted to the CPU execution provider.
        ``input_name`` is the name of the model's first input.
    """
    cpu_only = ["CPUExecutionProvider"]
    sess = ort.InferenceSession(MODEL_PATH, providers=cpu_only)
    first_input = sess.get_inputs()[0]
    return sess, first_input.name
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into the batched float tensor fed to the ONNX model.

    The image is first rotated upright according to its EXIF orientation tag,
    if present. It is then converted to RGB, resized to ``IMG_SIZE`` x
    ``IMG_SIZE`` (without preserving aspect ratio) and scaled from 0-255 to 0-1.

    Args:
        image: The image to prepare. Any mode PIL can convert to RGB is accepted.

    Returns:
        A float32 array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``: a batch of one,
        channels last, with values in [0, 1].
    """
    from PIL import ImageOps  # only this helper needs it; PIL is already a dependency

    # Phone cameras often store the pixels sideways and record the intended
    # rotation in EXIF; PIL ignores that tag unless it is applied explicitly,
    # which would hand the model a rotated fish.
    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    image = image.resize((IMG_SIZE, IMG_SIZE))
    arr = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    return arr[np.newaxis, ...]  # add the batch axis -> (1, H, W, 3)
def predict(image: Image.Image):
    """Classify one image with the cached ONNX Runtime session.

    Args:
        image: The PIL image to classify.

    Returns:
        tuple: ``(pred_class, pred_conf, preds)``. ``pred_class`` is the
        ``CLASS_NAMES`` entry with the highest score. ``pred_conf`` is that
        score as a Python float. ``preds`` is the 1-D score vector for the
        single image in the batch.

    Raises:
        ValueError: If the model returns a different number of scores than
            there are entries in ``CLASS_NAMES``.
    """
    session, input_name = load_session()
    x = preprocess_image(image)
    # session.run returns one array per model output: take the first output,
    # then the first (and only) row of the batch.
    preds = session.run(None, {input_name: x})[0][0]
    if len(preds) != len(CLASS_NAMES):
        # Fail loudly: a label/model mismatch would otherwise show up as a
        # bare IndexError or go unnoticed entirely.
        raise ValueError(
            f"Model returned {len(preds)} scores but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} entries."
        )
    pred_idx = int(np.argmax(preds))
    pred_class = CLASS_NAMES[pred_idx]
    # NOTE(review): this score is presented as a confidence. It is only a
    # probability if the model's final layer applies softmax; confirm.
    pred_conf = float(preds[pred_idx])
    return pred_class, pred_conf, preds
# === STREAMLIT UI ===
st.set_page_config(page_title="Large-Scale Fish Classifier", page_icon="🐟")
st.title("🐟 Large-Scale Fish Classifier")
st.write("Upload een afbeelding van een vis en het model voorspelt de soort.")

uploaded_file = st.file_uploader("Upload een afbeelding", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open only reads the header; load() forces a full decode so a
        # corrupt or truncated upload fails here rather than inside st.image
        # or predict().
        image.load()
    except OSError:  # also covers PIL.UnidentifiedImageError (a subclass)
        st.error("Kan dit bestand niet openen als afbeelding. Probeer een ander bestand.")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; switch once the pinned version allows it.
    st.image(image, caption="Geüploade afbeelding", use_column_width=True)

    if st.button("Classify"):
        with st.spinner("Bezig met voorspellen..."):
            pred_class, pred_conf, preds = predict(image)

        st.subheader("Voorspelling")
        st.write(f"**{pred_class}** met **{pred_conf:.2%}** zekerheid.")

        st.subheader("Class probabilities")
        st.bar_chart(
            {CLASS_NAMES[i]: float(preds[i]) for i in range(len(CLASS_NAMES))}
        )
else:
    st.info("➡️ Upload eerst een afbeelding (jpg/jpeg/png).")
|