import pathlib
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from fastai.vision.all import load_learner
from PIL import Image


@st.cache_resource
def load_model():
    """Load the exported fastai learner from ``model/export.pkl``.

    Wrapped in ``st.cache_resource`` so the learner is loaded once and reused
    across reruns instead of being deserialized on every interaction.
    NOTE: ``load_learner`` unpickles the file; only ship trusted model files.
    """
    return load_learner("model/export.pkl")


@st.cache_resource
def load_picture(img_path):
    """Open an image from a filesystem path or a file-like object (an upload).

    The result is cached by ``st.cache_resource``, so the same ``Image`` object
    is handed to every caller with an equal argument.
    """
    image = Image.open(img_path)
    # Image.open is lazy: it only reads the header. Decode the pixels now so
    # the cached object no longer depends on the source file handle/stream
    # (Pillow also releases the handle for path-based opens once loaded).
    image.load()
    return image


model = load_model()

if "selected_image_path" not in st.session_state:
    # Either a str path to a bundled sample image or an uploaded file object.
    st.session_state.selected_image_path = None
if "uploader_key" not in st.session_state:
    # Part of the uploader's widget key; bumping it resets the uploader.
    st.session_state.uploader_key = 0

st.title("Image Upload with Drag-and-Drop")
uploaded_file = st.file_uploader(
    "Drag and drop an image file here, or click to select",
    type=["jpg", "jpeg", "png"],
    key=f"uploader_{st.session_state.uploader_key}",
)
if uploaded_file is not None:
    st.session_state.selected_image_path = uploaded_file
    image = load_picture(uploaded_file)
    st.image(image, caption="Uploaded Image", width=200)
    st.write("Image uploaded successfully!")
elif not isinstance(st.session_state.selected_image_path, str):
    # The uploader is empty (the user removed the file, or it was reset), so
    # any previously stored upload is stale; drop it so "Submit" cannot
    # classify an image that is no longer there. Sample images are stored as
    # str paths and are intentionally kept.
    st.session_state.selected_image_path = None

# Class names in the order of the model's output probabilities.
# NOTE(review): this order must match the learner's vocabulary
# (model.dls.vocab). Spellings such as "jalepeno" and "raddish" presumably
# mirror the training dataset's folder names; verify before "correcting" them.
LABELS = [
    "apple", "banana", "beetroot", "bell pepper", "cabbage", "capsicum",
    "carrot", "cauliflower", "chilli pepper", "corn", "cucumber", "eggplant",
    "garlic", "ginger", "grapes", "jalepeno", "kiwi", "lemon", "lettuce",
    "mango", "onion", "orange", "paprika", "pear", "peas", "pineapple",
    "pomegranate", "potato", "raddish", "soy beans", "spinach", "sweetcorn",
    "sweetpotato", "tomato", "turnip", "watermelon",
]


def get_predictions(image, model):
    """Classify *image* with *model* and rank every class by probability.

    Args:
        image: A PIL image to classify.
        model: The fastai learner returned by ``load_model``.

    Returns:
        ``(results, predicted_label)``. ``results`` is a DataFrame with columns
        ``"labels"`` and ``"predictions"`` (probabilities rounded to 6 places),
        sorted by descending probability. ``predicted_label`` is the decoded
        class that ``model.predict`` returns first.
    """
    # Uploaded PNGs may be RGBA, grayscale or palette images; the model is
    # presumably trained on 3-channel RGB input, so normalise the mode first.
    # fastai's Learner.predict returns (decoded, class_index, probabilities).
    prediction = model.predict(image.convert("RGB"))
    probabilities = prediction[2].numpy().round(6)
    results = pd.DataFrame(
        {"labels": LABELS, "predictions": probabilities}
    ).sort_values(by="predictions", ascending=False)
    return results, prediction[0]


st.title("Image Classification")

# Display name -> path of a bundled sample image offered as a quick pick.
preselected_images = {
    "Apple": "data/apple_Image_1.jpg",
    "Banana": "data/banana_Image_1.jpg",
    "Carrot": "data/carrot_Image_1.jpg",
    "Tomato": "data/tomato_Image_1.jpg",
    "Corn": "data/corn_Image_1.jpg",
}

st.subheader("Try Classifying These Images")
st.write("Select an image")
columns = st.columns(len(preselected_images))

# How many of the most likely classes the bar chart shows.
TOP_K = 5


def update_key():
    """Bump the uploader's key so Streamlit renders a fresh, empty uploader.

    Used as the ``on_click`` callback of the "Select ..." buttons, so that
    choosing a sample image discards any file currently in the uploader.
    """
    st.session_state.uploader_key += 1


for column, (label, img_path) in zip(columns, preselected_images.items()):
    with column:
        thumbnail = load_picture(img_path).resize((150, 150))
        st.image(thumbnail, caption=label, use_container_width=True)
        if st.button(f"Select {label}", on_click=update_key):
            st.session_state.selected_image_path = img_path
            st.write(f"Selected image: {label}")
            # A sample image was chosen, not uploaded; caption it accordingly.
            st.image(load_picture(img_path), caption="Selected Image", width=200)

st.write("Click the button below to classify the uploaded image.")

if st.button("Submit"):
    selected = st.session_state.selected_image_path
    if selected:
        image = load_picture(selected)
        st.image(image, caption="Selected Image", width=200)
        results, prediction = get_predictions(image, model)
        st.write(f"Prediction: {prediction}")

        fig, ax = plt.subplots(figsize=(8, 5))
        # Draw on this figure's own axes rather than pyplot's global
        # "current axes", which can be shared by concurrently running sessions.
        sns.barplot(
            data=results.head(TOP_K),
            x="predictions",
            y="labels",
            orient="h",
            ax=ax,
        )
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; without this
        # each Submit would leak one figure.
        plt.close(fig)
    else:
        st.warning("Please upload or select an image first.")