# Scraped page header (not part of the program), kept for provenance:
#   BeyzaTopbas's picture
#   Update app.py
#   a199367 verified
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# ------------------------------ configuration ------------------------------
# Size passed to PIL's Image.resize before inference. It has to equal the
# target_size given to the ImageDataGenerator used for training.
IMG_SIZE = (170, 170)

# Display labels: entry i is reported for model output index i.
# NOTE(review): presumably this is the class order the training generator
# assigned - verify against the training run before editing.
CLASS_NAMES = ["Black Rot", "ESCA", "Healthy", "Leaf Blight"]

# Path of the TFLite model file, used exactly as written when loading.
TFLITE_MODEL_FILENAME = "grape_disease_model.tflite"
# ---------------------------------------------------------------------------
@st.cache_resource
def load_interpreter():
    """Create the TFLite interpreter for the model file and return it.

    The function is wrapped in ``st.cache_resource``, so the interpreter is
    built once and the cached object is handed back on later calls.

    If ``TFLITE_MODEL_FILENAME`` does not exist, the problem is reported with
    ``st.error`` and ``st.stop()`` is called instead of building anything.

    Returns:
        A ``tf.lite.Interpreter`` whose tensors have already been allocated.
    """
    model_path = TFLITE_MODEL_FILENAME
    model_is_missing = not os.path.exists(model_path)

    if model_is_missing:
        st.error(f"Model file '{TFLITE_MODEL_FILENAME}' not found in this directory.")
        st.stop()

    tflite = tf.lite.Interpreter(model_path=model_path)
    # Tensors must be allocated before any set_tensor / invoke call.
    tflite.allocate_tensors()
    return tflite
def preprocess_image(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-item float32 batch for the model.

    Steps:
    - Convert to RGB if the image is in any other mode, so the array always
      has exactly three channels (grayscale, palette, RGBA and CMYK inputs
      would otherwise produce a different shape).
    - Resize to ``IMG_SIZE``.
    - Scale pixel values by 1/255 (same as ``rescale=1./255`` in the training
      ImageDataGenerator).
    - Add a leading batch dimension.

    Args:
        img: The image to prepare. It is not modified.

    Returns:
        A float32 array with a leading batch axis of length 1, i.e.
        ``(1, H, W, 3)``.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    resized = img.resize(IMG_SIZE)

    # Same normalization as during training: uint8 [0, 255] -> float32 [0, 1].
    pixels = np.asarray(resized, dtype=np.float32) / 255.0

    # (H, W, 3) -> (1, H, W, 3)
    return np.expand_dims(pixels, axis=0)
def predict_image(img: Image.Image):
    """Classify one image with the cached TFLite interpreter.

    Args:
        img: PIL image to classify; it is passed through ``preprocess_image``.

    Returns:
        A tuple ``(predicted_label, confidence, all_probabilities)`` where
        ``predicted_label`` is the ``CLASS_NAMES`` entry at the arg-max index,
        ``confidence`` is the model output at that index as a Python float,
        and ``all_probabilities`` is the first row of the model output.

    Raises:
        ValueError: If the model emits a different number of values than
            there are entries in ``CLASS_NAMES``. Without this check a
            mismatch would either fail with an opaque IndexError or silently
            pair probabilities with the wrong labels.
    """
    interpreter = load_interpreter()
    # NOTE(review): this is the shared instance cached by load_interpreter;
    # confirm that concurrent sessions cannot interleave set_tensor/invoke.
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # asarray only copies when the dtype is not already float32.
    batch = np.asarray(preprocess_image(img), dtype=np.float32)

    # Feed input and run the model.
    interpreter.set_tensor(input_index, batch)
    interpreter.invoke()

    # Output has a leading batch axis; take the single item of the batch.
    probs = interpreter.get_tensor(output_index)[0]

    n_outputs = int(np.size(probs))
    if n_outputs != len(CLASS_NAMES):
        raise ValueError(
            f"Model produced {n_outputs} output value(s) but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} entries; the model and label list do not match."
        )

    idx = int(np.argmax(probs))
    label = CLASS_NAMES[idx]
    confidence = float(probs[idx])
    return label, confidence, probs
# ================ STREAMLIT UI =================
# Top-level script: configures the page, shows an uploader, and on request
# runs predict_image on the uploaded picture and prints the results.
st.set_page_config(
    page_title="Grape Disease Classifier",
    layout="centered",
)

# The title's leading emoji had been mis-decoded into mojibake; it is written
# as an escape (U+1F347, grapes) so it survives further re-encoding.
st.title("\U0001F347 Grape Disease Classifier")
st.write("App loaded. Upload an image to get a prediction.")
st.write(
    "Upload a photo of a grape leaf. "
    "The model predicts whether the leaf is **Black Rot**, **ESCA**, **Healthy**, "
    "or **Leaf Blight**."
)

uploaded_file = st.file_uploader(
    "Upload an image (.jpg, .jpeg or .png)",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow signals unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass). Show a message instead of
        # letting the traceback reach the user.
        st.error("The uploaded file could not be read as an image. Please try another file.")
        st.stop()

    # NOTE(review): use_column_width is reported as deprecated in newer
    # Streamlit releases; confirm the deployed version before switching.
    st.image(image, caption="Uploaded image", use_column_width=True)

    if st.button("Predict"):
        with st.spinner("Running prediction..."):
            label, confidence, probs = predict_image(image)

        st.subheader(f"Prediction: **{label}**")
        st.write(f"Confidence: **{confidence:.2%}**")
        st.write("Class probabilities:")
        for name, p in zip(CLASS_NAMES, probs):
            st.write(f"- {name}: {float(p):.2%}")
else:
    st.info("Upload an image above to get a prediction.")