indian_bird / src /streamlit_app.py
BeyzaTopbas's picture
Update src/streamlit_app.py
6c560c9 verified
import streamlit as st
import numpy as np
from PIL import Image
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.resnet50 import preprocess_input
# -----------------------------------------------------------
# 1. Configuration
# -----------------------------------------------------------
# Input resolution fed to the network; must match the size used in training.
IMG_SIZE = (128, 128)

# The model file lives next to this script (in src/).
BASE_DIR = Path(__file__).parent
MODEL_PATH = BASE_DIR.joinpath("indian_birds_resnet50.h5")

# Class names, in exactly the order used during training: position i in this
# list is the class whose score appears at index i of the model output.
labels = [
    "Asian Green Bee-Eater",
    "Brown-Headed Barbet",
    "Cattle Egret",
    "Common Kingfisher",
    "Common Myna",
    "Common Rosefinch",
    "Common Tailorbird",
    "Coppersmith Barbet",
    "Forest Wagtail",
    "Gray Wagtail",
    "Hoopoe",
    "House Crow",
    "Indian Grey Hornbill",
    "Indian Peafowl",
    "Indian Pitta",
    "Indian Roller",
    "Jungle Babbler",
    "Northern Lapwing",
    "Red Wattled Lapwing",
    "Ruddy Shelduck",
    "Rufous Treepie",
    "Sarus Crane",
    "White Wagtail",
    "White-Breasted Kingfisher",
    "White-Breasted Waterhen",
]

# name -> class index, and the reverse lookup class index -> name.
label_map = {name: idx for idx, name in enumerate(labels)}
inv_label_map = dict(zip(label_map.values(), label_map.keys()))
# -----------------------------------------------------------
# 2. Load the model (cached)
# -----------------------------------------------------------
@st.cache_resource(show_spinner=False)
def load_bird_model():
    """Load the trained ResNet50 bird classifier from MODEL_PATH.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    on every rerun instead of being reloaded from disk each time.

    Returns:
        The Keras model stored at MODEL_PATH, loaded without compiling.

    Raises:
        FileNotFoundError: if MODEL_PATH does not point to an existing file.
    """
    if not MODEL_PATH.is_file():
        # Fail with an explicit message instead of a low-level HDF5/Keras error.
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
    # str(): load_model is given a plain string path rather than a Path.
    # compile=False: the app only calls predict(), so the training
    # optimizer/loss configuration does not need to be restored.
    return load_model(str(MODEL_PATH), compile=False)


# show_spinner=False keeps this call from emitting a spinner element before
# st.set_page_config() further down; older Streamlit releases require
# set_page_config to be the first Streamlit command in the script.
model = load_bird_model()
# -----------------------------------------------------------
# 3. Helper functions
# -----------------------------------------------------------
def preprocess_image(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a one-image batch for the ResNet50 model.

    Steps: honour the EXIF orientation tag, convert to RGB, resize to
    IMG_SIZE, apply ``resnet50.preprocess_input`` and add a batch axis.

    Args:
        img: Any PIL image, e.g. an uploaded JPEG or PNG.

    Returns:
        float32 array of shape ``(1, IMG_SIZE[1], IMG_SIZE[0], 3)``
        (PIL sizes are (width, height); NumPy arrays are (height, width)).
    """
    from PIL import ImageOps  # local import: only this helper needs it

    # Phone cameras often record rotation in EXIF metadata instead of in the
    # pixel data; apply it so the model sees the image upright rather than
    # sideways. Images without an orientation tag are left unchanged.
    img = ImageOps.exif_transpose(img)
    img = img.convert("RGB").resize(IMG_SIZE)
    # np.array copies, so the result is writable for preprocess_input.
    x = np.array(img, dtype="float32")
    x = preprocess_input(x)  # ResNet50 preprocessing, as during training
    return np.expand_dims(x, axis=0)  # add the batch dimension
def predict_image(img: Image.Image, top_k: int = 3):
    """Classify a bird image with the loaded model.

    Args:
        img: PIL image to classify.
        top_k: Number of ranked predictions to return. Defaults to 3, the
            original behaviour; values above the number of classes simply
            return every class.

    Returns:
        ``(pred_label, pred_prob, top)``:
        - ``pred_label``/``pred_prob``: best class name and its score.
        - ``top``: list of ``(label, score)`` tuples, highest score first.
        ``pred_label``/``pred_prob`` always equal ``top[0]``.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    x = preprocess_image(img)
    probs = model.predict(x, verbose=0)[0]  # one score per class, e.g. (25,)

    # Rank once with a stable sort on the negated scores. Ties then resolve
    # to the lowest class index, exactly like np.argmax, and the headline
    # prediction can never disagree with the first ranked entry.
    ranked_ids = np.argsort(-probs, kind="stable")[:top_k]
    top = [(inv_label_map[int(i)], float(probs[i])) for i in ranked_ids]

    pred_label, pred_prob = top[0]
    return pred_label, pred_prob, top
# -----------------------------------------------------------
# 4. Streamlit UI
# -----------------------------------------------------------
st.set_page_config(page_title="Indian Bird Classifier", layout="centered")
st.title("🕊️ Indian Bird Species Classifier")
st.write(
    "Upload een vogelafbeelding (uit of vergelijkbaar met de "
    "**25 Indian Bird Species** dataset) en het ResNet50-model "
    "voorspelt de soort."
)

uploaded_file = st.file_uploader(
    "Kies een afbeelding (.jpg, .png)",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Decode the upload up front. A corrupt or mislabelled file raises
    # OSError (PIL.UnidentifiedImageError is a subclass), and an oversized
    # image raises DecompressionBombError. Report either to the user instead
    # of crashing the app with a traceback.
    try:
        image = Image.open(uploaded_file)
        image.load()  # force full decoding so truncated files fail here
    except (OSError, Image.DecompressionBombError):
        st.error(
            "Dit bestand kon niet als afbeelding worden gelezen. "
            "Probeer een andere afbeelding."
        )
        st.stop()

    # Show the uploaded image.
    # NOTE(review): use_column_width is deprecated in recent Streamlit releases
    # in favour of use_container_width; confirm the installed version before
    # switching.
    st.image(image, caption="Geüploade afbeelding", use_column_width=True)

    if st.button("🔮 Voorspel vogelsoort"):
        with st.spinner("Model is aan het voorspellen..."):
            pred_label, pred_prob, top3 = predict_image(image)

        st.success(f"Voorspelling: **{pred_label}** ({pred_prob:.2%} zekerheid)")
        st.subheader("Top 3 voorspellingen")
        for name, p in top3:
            st.write(f"- {name}: **{p:.2%}**")
else:
    st.info("Upload een afbeelding om te starten.")