import streamlit as st import numpy as np import pandas as pd from PIL import Image import tensorflow as tf # --- PAS DIT AAN ALS JE ANDERE KLASSEN HEBT --- CLASS_NAMES = [ "Ajwa", "Galaxy", "Medjool", "Meneifi", "Nabtat Ali", "Rutab", "Shaishe", "Sokari", "Sugaey" ] IMG_SIZE = (128, 128) # dezelfde grootte als in je model @st.cache_resource def load_model(): """Laad het Keras-model één keer en cache het.""" model = tf.keras.models.load_model("date_fruit_model.h5") return model def preprocess_image(image: Image.Image) -> np.ndarray: """ Maakt een PIL-image klaar voor het model: - resize - naar np.array - normaliseren [0,1] - batch-dimensie toevoegen """ image = image.convert("RGB") # voor de zekerheid image = image.resize(IMG_SIZE) arr = np.array(image, dtype="float32") / 255.0 arr = np.expand_dims(arr, axis=0) # shape: (1, 128, 128, 3) return arr def main(): st.set_page_config(page_title="Date Fruit Classifier", layout="centered") st.title("🍇 Date Fruit Classifier") st.write("Upload een foto van een dadel en het model probeert de soort te raden.") # Sidebar info st.sidebar.header("Info") st.sidebar.write("Model: Convolutional Neural Network (Keras/TensorFlow)") st.sidebar.write(f"Aantal klassen: **{len(CLASS_NAMES)}**") uploaded_file = st.file_uploader( "Kies een afbeelding", type=["jpg", "jpeg", "png"] ) if uploaded_file is not None: # Toon de geüploade afbeelding image = Image.open(uploaded_file) st.image(image, caption="Geüploade afbeelding", use_container_width=True) if st.button("Classificeer"): with st.spinner("Bezig met voorspellen..."): model = load_model() input_arr = preprocess_image(image) preds = model.predict(input_arr)[0] # shape: (n_classes,) pred_idx = int(np.argmax(preds)) pred_name = CLASS_NAMES[pred_idx] pred_conf = float(preds[pred_idx]) st.subheader("🔎 Voorspelling") st.write(f"**Klasse:** {pred_name}") st.write(f"**Vertrouwen:** {pred_conf:.2%}") # Probabilities per klasse st.subheader("📊 Waarschijnlijkheid per klasse") probs_df = pd.DataFrame({ "Klasse": CLASS_NAMES, "Score": preds }) probs_df = probs_df.set_index("Klasse") st.bar_chart(probs_df) if __name__ == "__main__": main()