Spaces:
Sleeping
Sleeping
File size: 2,500 Bytes
e44b9ef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import streamlit as st
import numpy as np
from PIL import Image
import onnxruntime as ort
# === MODEL SETTINGS ===
# ONNX model file; a relative path, so it is resolved against the process's
# current working directory when the session is created.
MODEL_PATH = "cnn_largefish_model.onnx"

# Side length in pixels of the square image fed to the model.
IMG_SIZE = 64

# Label for each position of the model's output vector: CLASS_NAMES[i] is the
# class whose score sits at index i.
# NOTE(review): this order must match the label encoding used when the model
# was trained — confirm against the training script before reordering.
CLASS_NAMES = [
    "House Mackerel",
    "Black Sea Sprat",
    "Sea Bass",
    "Red Mullet",
    "Trout",
    "Striped Red Mullet",
    "Shrimp",
    "Gilt-Head Bream",
    "Red Sea Bream",
]


@st.cache_resource
def load_session():
    """Build the ONNX Runtime session for the fish model, once.

    ``st.cache_resource`` makes Streamlit keep the returned objects across
    script reruns, so the model file is read a single time rather than on
    every interaction.

    Returns:
        tuple: ``(session, input_name)`` — an ``ort.InferenceSession`` for
        ``MODEL_PATH`` restricted to the CPU execution provider, and the name
        of the model's first input, used as the feed-dict key for
        ``session.run``.
    """
    cpu_only = ["CPUExecutionProvider"]
    sess = ort.InferenceSession(MODEL_PATH, providers=cpu_only)
    first_input = sess.get_inputs()[0]
    return sess, first_input.name


def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-image batch for the model.

    The image is forced to 3-channel RGB, resized to ``IMG_SIZE`` x
    ``IMG_SIZE``, and its 0-255 byte values are scaled to float32 in [0, 1].

    Args:
        image: Any PIL image; its mode is converted to RGB first.

    Returns:
        float32 array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``.
    """
    rgb = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(rgb).astype("float32") / 255.0
    # Prepend the batch axis: (H, W, 3) -> (1, H, W, 3).
    return pixels[np.newaxis, ...]


def predict(image: Image.Image):
    """Classify one fish image with the ONNX model.

    Args:
        image: PIL image to classify; it is passed through
            ``preprocess_image`` before inference.

    Returns:
        tuple: ``(pred_class, pred_conf, preds)`` — the ``CLASS_NAMES`` label
        with the highest score, that score as a Python ``float``, and the
        model's full score vector for this image (one entry per class).

    Raises:
        ValueError: if the model's score vector does not have exactly one
            entry per name in ``CLASS_NAMES``. In that case the index-to-label
            mapping is wrong, and every prediction would be mislabelled (or
            fail later with an opaque IndexError).
    """
    session, input_name = load_session()
    x = preprocess_image(image)
    # session.run returns a list with one array per model output; take the
    # first output, then row 0 of the batch-of-one built by preprocess_image.
    preds = session.run(None, {input_name: x})[0][0]
    if len(preds) != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {len(preds)} scores but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} entries; the label list does not match the model."
        )
    pred_idx = int(np.argmax(preds))
    # NOTE(review): the UI presents this score as a probability, which only
    # holds if the exported model ends in a softmax — confirm in the graph.
    return CLASS_NAMES[pred_idx], float(preds[pred_idx]), preds


# === STREAMLIT UI ===
st.set_page_config(page_title="Large-Scale Fish Classifier", page_icon="🐟")
st.title("🐟 Large-Scale Fish Classifier")
st.write("Upload een afbeelding van een vis en het model voorspelt de soort.")

uploaded_file = st.file_uploader("Upload een afbeelding", type=["jpg", "jpeg", "png"])

if uploaded_file is None:
    st.info("➡️ Upload eerst een afbeelding (jpg/jpeg/png).")
else:
    # The uploader filters on file extension only, so the content can still
    # be corrupt or not an image at all. Image.open is lazy; load() forces a
    # full decode here so such files are reported cleanly instead of raising a
    # traceback later inside st.image() or predict().
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised data and a plain OSError for truncated files.
        image = None
        st.error("Kan deze afbeelding niet lezen. Upload een geldig jpg/jpeg/png-bestand.")

    if image is not None:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width — switch once the deployed Streamlit
        # version is confirmed to support the new parameter.
        st.image(image, caption="Geüploade afbeelding", use_column_width=True)
        if st.button("Classify"):
            with st.spinner("Bezig met voorspellen..."):
                pred_class, pred_conf, preds = predict(image)
            st.subheader("Voorspelling")
            st.write(f"**{pred_class}** met **{pred_conf:.2%}** zekerheid.")
            st.subheader("Class probabilities")
            # One bar per class: pair each label with the score at the same index.
            st.bar_chart(
                {name: float(score) for name, score in zip(CLASS_NAMES, preds)}
            )
|