RiceClassification / src /streamlit_app.py
Osmanerendgn's picture
Update src/streamlit_app.py
4b21fe5 verified
import streamlit as st
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from pathlib import Path
# Page metadata; called before any other Streamlit command (older Streamlit
# versions require this ordering).
st.set_page_config(
    page_title="Rice Classification",
    page_icon="🍚",
    layout="centered",
)

# Saved Keras model, resolved from this file's grandparent directory
# (the project root when this script sits in src/).
MODEL_PATH = (
    Path(__file__).resolve().parents[1]
    / "src"
    / "rice_efficientnet_feature_extractor.keras"
)

# Human-readable labels, indexed by the model's output position.
# NOTE(review): this order must match the class order used at training time
# (alphabetical, if the data was loaded with image_dataset_from_directory) — confirm.
CLASS_NAMES = ["Arborio", "Basmati", "Ipsala", "Jasmine", "Karacadag"]
@st.cache_resource
def load_cached_model():
    """Load the Keras classifier from ``MODEL_PATH``.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    model instead of reading it from disk on every interaction.

    Returns:
        The model object returned by ``tensorflow.keras.models.load_model``.
    """
    return load_model(MODEL_PATH)


# Top-level boundary: a missing or corrupt model file should produce a readable
# message in the app rather than a raw traceback. st.stop() ends this script
# run, so nothing below ever sees an undefined ``model``.
try:
    model = load_cached_model()
except Exception as err:
    st.error(f"Model load failed. Check MODEL_PATH ({MODEL_PATH}).\n\nError: {err}")
    st.stop()
def preprocess(File):
    """Turn an uploaded image into a single-image batch for the model.

    Args:
        File: anything ``PIL.Image.open`` accepts — here, the object returned
            by ``st.file_uploader``.

    Returns:
        ``(img, x)``: the image converted to RGB and resized to 224x224 (for
        display), and its pixels as a NumPy array with a leading batch axis
        of length 1.
    """
    img = Image.open(File).convert("RGB")
    img = img.resize((224, 224))
    # Raw 0-255 pixel values are passed through unscaled.
    # NOTE(review): assumes the saved model rescales its own input (Keras
    # EfficientNet backbones include that step) — confirm.
    x = np.expand_dims(np.array(img), axis=0)
    return img, x


def topk(prob, k=3):
    """Return the ``k`` highest-scoring classes, best first.

    Args:
        prob: 1-D sequence of per-class scores (list or NumPy array).
        k: number of entries to keep; values above ``len(prob)`` keep all,
            values below 1 keep none.

    Returns:
        ``(idx, values)``: NumPy arrays of the selected class indices and
        their scores, in descending score order.
    """
    prob = np.asarray(prob)  # lets plain lists be fancy-indexed below
    k = max(int(k), 0)  # a negative slice bound would drop from the end instead
    idx = np.argsort(prob)[::-1][:k]
    return idx, prob[idx]
# Page header. Model-load error handling lives at the load site, not here.
st.title("🍚 Rice Classification (Transfer Learning)")
st.write("Upload an image and get the predicted rice type.")
# Upload widget; returns None until the user picks a file.
file = st.file_uploader("Upload an image", type=["jpg", "jpeg"])

if file is not None:
    try:
        pil_img, x = preprocess(file)
    except OSError as err:
        # Pillow raises OSError (UnidentifiedImageError is a subclass) for
        # corrupt, truncated, or non-image uploads; report instead of crashing.
        st.error(f"Could not read the uploaded image.\n\nError: {err}")
        st.stop()

    # Show the upload first so it is visible while inference runs.
    st.image(pil_img, caption="Uploaded image", use_container_width=True)

    preds = np.asarray(model.predict(x, verbose=0))
    prob = preds[0]  # scores for the single image in the batch
    best_idx = int(np.argmax(prob))
    best_label = CLASS_NAMES[best_idx]
    # NOTE(review): assumes the model ends in a softmax so scores are
    # probabilities in [0, 1]; with logits the percentage is meaningless.
    best_conf = float(prob[best_idx])

    st.subheader("Prediction")
    st.success(f"{best_label} | confidence: {best_conf:.2%}")
else:
    st.caption("No image uploaded yet.")