Spaces:
Build error
Build error
| import pathlib | |
| import time | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| import streamlit as st | |
| from fastai.vision.all import load_learner | |
| from PIL import Image | |
@st.cache_resource
def load_model(model_path="model/export.pkl"):
    """Load the exported fastai learner used for classification.

    Streamlit re-executes this script from the top on every widget
    interaction, so without caching the learner would be unpickled again on
    each click. ``st.cache_resource`` keeps a single shared instance per
    ``model_path`` across reruns and sessions.

    Args:
        model_path: Path to the exported learner file. Defaults to the
            bundled ``model/export.pkl``.

    Returns:
        The object returned by fastai's ``load_learner``.
    """
    # load_learner unpickles the file: only ever point this at a trusted,
    # locally shipped model, never at a user-supplied path.
    return load_learner(model_path)
def load_picture(img_path):
    """Open an image for display and classification.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts -- a filesystem path
            or a binary file-like object such as a Streamlit upload.

    Returns:
        A PIL image whose pixel data has already been read.
    """
    image = Image.open(img_path)
    # Decode eagerly: a truncated/corrupt file fails here instead of later
    # inside st.image or the model, and Pillow can release a file handle it
    # opened itself once the image data has been fully read.
    image.load()
    return image
# Classifier used by the Submit handler further down the script.
model = load_model()

# Seed per-session state on the first run; later reruns keep existing values.
for _state_key, _initial_value in (("selected_image_path", None), ("uploader_key", 0)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

st.title("Image Upload with Drag-and-Drop")

# The widget key embeds a session counter; bumping the counter swaps in a
# brand-new, empty uploader widget.
uploaded_file = st.file_uploader(
    "Drag and drop an image file here, or click to select",
    type=["jpg", "jpeg", "png"],
    key="uploader_{}".format(st.session_state.uploader_key),
)

if uploaded_file is not None:
    # Remember the upload so a later Submit click classifies this file.
    st.session_state.selected_image_path = uploaded_file
    image = load_picture(uploaded_file)
    st.image(image, caption="Uploaded Image", width=200)
    st.write("Image uploaded successfully!")
# Fallback class names, in alphabetical order. Used only when the model does
# not expose its own vocabulary and no explicit ``labels`` are passed.
DEFAULT_LABELS = (
    "apple",
    "banana",
    "beetroot",
    "bell pepper",
    "cabbage",
    "capsicum",
    "carrot",
    "cauliflower",
    "chilli pepper",
    "corn",
    "cucumber",
    "eggplant",
    "garlic",
    "ginger",
    "grapes",
    "jalepeno",
    "kiwi",
    "lemon",
    "lettuce",
    "mango",
    "onion",
    "orange",
    "paprika",
    "pear",
    "peas",
    "pineapple",
    "pomegranate",
    "potato",
    "raddish",
    "soy beans",
    "spinach",
    "sweetcorn",
    "sweetpotato",
    "tomato",
    "turnip",
    "watermelon",
)


def get_predictions(image, model, labels=None):
    """Classify ``image`` and return per-class probabilities.

    Args:
        image: A PIL image (e.g. from ``load_picture``). Images whose mode
            is not ``"RGB"`` (RGBA PNGs, grayscale JPEGs, ...) are converted
            to RGB before prediction.
        model: A fastai learner, or anything whose ``predict(image)`` returns
            a sequence where item 0 is the decoded label and item 2 is a
            probability tensor exposing ``.numpy()``.
        labels: Optional class names aligned with the probability vector.
            Defaults to ``model.dls.vocab`` when the model has one, otherwise
            ``DEFAULT_LABELS``.

    Returns:
        ``(results, prediction)``: ``results`` is a DataFrame with columns
        ``labels`` and ``predictions`` (rounded to 6 decimals), sorted by
        descending probability; ``prediction`` is the model's decoded label.

    Raises:
        ValueError: If the number of labels differs from the number of
            probabilities produced by the model.
    """
    # fastai passes an already-open PIL image through without changing its
    # mode, so a 4-channel or 1-channel image would not match an RGB network.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    prediction = model.predict(image)
    probabilities = prediction[2].numpy().round(6)

    if labels is None:
        # Prefer the model's own class order over a hand-maintained list.
        vocab = getattr(getattr(model, "dls", None), "vocab", None)
        labels = list(vocab) if vocab is not None else list(DEFAULT_LABELS)
    else:
        labels = list(labels)

    if len(labels) != len(probabilities):
        raise ValueError(
            f"Got {len(probabilities)} probabilities but {len(labels)} labels; "
            "pass `labels` matching the model's classes."
        )

    results = pd.DataFrame(
        {"labels": labels, "predictions": probabilities}
    ).sort_values(by="predictions", ascending=False)
    return results, prediction[0]
st.title("Image Classification")

# Bundled sample images offered as one-click alternatives to uploading.
# Keys are display labels; values are relative file paths.
preselected_images = {
    name: f"data/{name.lower()}_Image_1.jpg"
    for name in ("Apple", "Banana", "Carrot", "Tomato", "Corn")
}

st.subheader("Try Classifying These Images")
st.write("Select an image")

# One layout column per sample image.
columns = st.columns(len(preselected_images))
def update_key():
    """Advance the session's uploader counter.

    Used as an ``on_click`` callback. The file uploader's widget key is built
    from this counter, so changing it makes Streamlit create a new, empty
    uploader widget on the next run, discarding any previous upload.
    """
    current = st.session_state["uploader_key"]
    st.session_state["uploader_key"] = current + 1
for column, (label, img_path) in zip(columns, preselected_images.items()):
    with column:
        # Thumbnail squashed to a fixed 150x150 so the sample row lines up.
        thumbnail = load_picture(img_path).resize((150, 150))
        st.image(thumbnail, caption=label, use_container_width=True)

        # update_key runs as the click callback before the rerun that reports
        # the click, so the file uploader comes back empty.
        if not st.button(f"Select {label}", on_click=update_key):
            continue

        st.session_state.selected_image_path = img_path
        st.write(f"Selected image: {label}")
        image = load_picture(img_path)
        st.image(image, caption="Uploaded Image", width=200)
st.write("Click the button below to classify the uploaded image.")

if st.button("Submit"):
    selected = st.session_state.selected_image_path
    if selected is not None:
        image = load_picture(selected)
        st.image(image, caption="Selected Image", width=200)

        results, prediction = get_predictions(image, model)
        st.write(f"Prediction: {prediction}")

        # Horizontal bar chart of the five most probable classes; `results`
        # is already sorted by descending probability.
        fig, ax = plt.subplots(figsize=(8, 5))
        sns.barplot(
            data=results.head(5),
            x="predictions",
            y="labels",
            orient="h",
            ax=ax,
        )
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; Streamlit reruns
        # this script on each interaction, so release it after rendering.
        plt.close(fig)
    else:
        # Previously a Submit click with nothing selected did nothing at all.
        st.warning("Please upload an image or select one of the samples first.")