el-filatova's picture
Update app.py
1fdf243 verified
import pathlib
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from fastai.vision.all import load_learner
from PIL import Image
@st.cache_resource
def load_model():
    """Return the exported fastai learner stored at ``model/export.pkl``.

    Wrapped in ``st.cache_resource`` so the (expensive) load is shared across
    Streamlit reruns instead of being repeated on every interaction.
    NOTE: ``load_learner`` unpickles the file, so it must stay a trusted,
    locally bundled artifact.
    """
    export_file = "model/export.pkl"
    learner = load_learner(export_file)
    return learner
@st.cache_resource
def load_picture(img_path):
    """Open an image and return a fully decoded RGB copy of it.

    Args:
        img_path: Either a filesystem path string (the bundled sample images)
            or the file-like object returned by ``st.file_uploader``.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.

    ``Image.open`` is lazy: it reads only the header and keeps the source
    open until pixel data is needed. Converting inside the ``with`` block
    forces a full decode, so the cached result no longer depends on an open
    file handle or on the upload buffer's read position.
    Normalising to RGB also matters for PNG uploads, which may be RGBA or
    palette images. The classifier presumably expects 3-channel input
    (TODO confirm against the trained model).
    """
    with Image.open(img_path) as img:
        return img.convert("RGB")
model = load_model()

# Per-session defaults. They are seeded only on a session's first run, so
# later reruns keep whatever the user has selected so far.
_SESSION_DEFAULTS = {"selected_image_path": None, "uploader_key": 0}
for _state_key, _default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
st.title("Image Upload with Drag-and-Drop")
# The widget key embeds a counter (bumped by the sample-image buttons) so the
# uploader can be reset to empty by giving it a new identity.
uploaded_file = st.file_uploader(
    "Drag and drop an image file here, or click to select",
    type=["jpg", "jpeg", "png"],
    key=f"uploader_{st.session_state.uploader_key}",
)
if uploaded_file is not None:
    # A fresh upload becomes the current selection.
    st.session_state.selected_image_path = uploaded_file
    image = load_picture(st.session_state.selected_image_path)
    st.image(image, caption="Uploaded Image", width=200)
    st.write("Image uploaded successfully!")
elif not isinstance(st.session_state.selected_image_path, str):
    # The uploader is empty, but the stored selection is not a sample-image
    # path. It is therefore a previous upload the user has since removed.
    # Drop it so "Submit" does not classify an image that is no longer shown.
    st.session_state.selected_image_path = None
# Fallback class names, listed in the order of the model's output
# probabilities. Used only when the learner does not expose its own
# vocabulary. The spellings must match the training labels exactly.
_DEFAULT_LABELS = (
    "apple",
    "banana",
    "beetroot",
    "bell pepper",
    "cabbage",
    "capsicum",
    "carrot",
    "cauliflower",
    "chilli pepper",
    "corn",
    "cucumber",
    "eggplant",
    "garlic",
    "ginger",
    "grapes",
    "jalepeno",
    "kiwi",
    "lemon",
    "lettuce",
    "mango",
    "onion",
    "orange",
    "paprika",
    "pear",
    "peas",
    "pineapple",
    "pomegranate",
    "potato",
    "raddish",
    "soy beans",
    "spinach",
    "sweetcorn",
    "sweetpotato",
    "tomato",
    "turnip",
    "watermelon",
)


def get_predictions(image, model):
    """Classify *image* and return per-class probabilities plus the top label.

    Args:
        image: The image handed unchanged to ``model.predict``.
        model: A fastai ``Learner``, or any object whose ``predict`` returns a
            sequence whose element 0 is the predicted label and element 2
            has a ``.numpy()`` method yielding one probability per class.

    Returns:
        ``(results, label)``. ``results`` is a DataFrame with columns
        ``labels`` and ``predictions`` (probabilities rounded to 6 decimals),
        sorted by descending probability. ``label`` is element 0 of the
        ``model.predict`` result.

    Raises:
        ValueError: from pandas, if the number of class names differs from
            the number of probabilities.
    """
    # Prefer the learner's own class vocabulary. It is guaranteed to be in
    # the same order as the probability vector, whereas the hard-coded list
    # silently mislabels results if the training classes ever change.
    vocab = getattr(getattr(model, "dls", None), "vocab", None)
    labels = list(vocab) if vocab is not None else list(_DEFAULT_LABELS)

    prediction = model.predict(image)
    results = pd.DataFrame(
        {"labels": labels, "predictions": prediction[2].numpy().round(6)}
    ).sort_values(by="predictions", ascending=False)
    return results, prediction[0]
st.title("Image Classification")

# Bundled sample images the user can classify without uploading anything,
# keyed by the caption shown under each thumbnail.
preselected_images = dict(
    Apple="data/apple_Image_1.jpg",
    Banana="data/banana_Image_1.jpg",
    Carrot="data/carrot_Image_1.jpg",
    Tomato="data/tomato_Image_1.jpg",
    Corn="data/corn_Image_1.jpg",
)

st.subheader("Try Classifying These Images")
st.write("Select an image")
# One column per sample so the thumbnails are laid out side by side.
sample_count = len(preselected_images)
columns = st.columns(sample_count)
def update_key():
    """Advance the counter embedded in the file uploader's widget key.

    Registered as the ``on_click`` callback of the sample-image buttons. A new
    key gives ``st.file_uploader`` a new identity on the next rerun, which
    clears any file currently shown in it.
    """
    current = st.session_state.uploader_key
    st.session_state.uploader_key = current + 1
# Render each sample as a thumbnail plus a "Select" button in its own column.
for column, (sample_name, sample_path) in zip(columns, preselected_images.items()):
    with column:
        thumbnail = load_picture(sample_path).resize((150, 150))
        st.image(thumbnail, caption=sample_name, use_container_width=True)

        clicked = st.button(f"Select {sample_name}", on_click=update_key)
        if clicked:
            # Remember the chosen sample so the Submit button can classify it.
            st.session_state.selected_image_path = sample_path
            st.write(f"Selected image: {sample_name}")
            st.image(load_picture(sample_path), caption="Uploaded Image", width=200)
st.write("Click the button below to classify the uploaded image.")
if st.button("Submit"):
    selected = st.session_state.selected_image_path
    if selected is None:
        # Previously a click with nothing selected silently did nothing.
        st.warning("Please upload or select an image first.")
    else:
        image = load_picture(selected)
        st.image(image, caption="Selected Image", width=200)
        results, prediction = get_predictions(image, model)
        st.write(f"Prediction: {prediction}")

        # Bar chart of the five most probable classes.
        fig, ax = plt.subplots(figsize=(8, 5))
        sns.barplot(
            data=results.iloc[:5],
            x="predictions",
            y="labels",
            orient="h",
            ax=ax,
        )
        st.pyplot(fig)
        # pyplot keeps every figure created via plt.subplots alive until it
        # is closed. Without this, each Submit click on the long-running
        # server would leak one figure.
        plt.close(fig)