|
|
import streamlit as st
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
import plotly.graph_objects as go
|
|
|
import tensorflow as tf
|
|
|
import time
|
|
|
import os
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the trained Keras classifier from disk.

    Decorated with ``st.cache_resource`` so the model is loaded once and the
    same object is reused across Streamlit reruns. The path is relative to the
    current working directory.
    """
    model_file = 'transfer_learning_model.h5'
    return tf.keras.models.load_model(model_file)
|
|
|
|
|
|
def preprocess_image(image, target_size=(299, 299)):
    """Resize and scale a PIL image into a single-image model batch.

    Args:
        image: PIL image in any mode. Non-RGB images (RGBA, greyscale,
            palette, ...) are converted to RGB first, so the batch always has
            exactly three channels.
        target_size: ``(width, height)`` passed to ``Image.resize``. The
            default matches the 299x299 input the app's model is used with.

    Returns:
        ``(resized, batch)``:

        - ``resized`` is the resized RGB image (the app shows it in the UI).
        - ``batch`` is a float array of shape ``(1, height, width, 3)`` with
          pixel values divided by 255.
    """
    # Without this, an RGBA PNG gives 4 channels and a greyscale image gives
    # a 2-D array; neither matches the model's input.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    resized = image.resize(target_size)
    scaled = np.asarray(resized) / 255.0
    batch = np.expand_dims(scaled, axis=0)  # add the leading batch axis
    return resized, batch
|
|
|
|
|
|
# Output labels, listed in the same order as the model's output scores.
CLASS_NAMES = (
    'Cardboard',
    'Food Organics',
    'Glass',
    'Metal',
    'Miscellaneous Trash',
    'Paper',
    'Plastic',
    'Textile Trash',
    'Vegetation',
)


def predict(image):
    """Classify *image* and return a score for every waste class.

    Args:
        image: PIL image to classify. It is resized and scaled by
            ``preprocess_image`` before being fed to the model.

    Returns:
        dict mapping each name in ``CLASS_NAMES`` to the model's score for
        that class, as a Python float. The UI treats these scores as
        probabilities (presumably a softmax output; confirm against the
        training code).

    Raises:
        ValueError: if the model does not emit exactly one score per class.
    """
    model = load_model()  # cached by st.cache_resource; cheap after the first call
    _, batch = preprocess_image(image)
    # verbose=0 stops Keras printing a progress bar to the server log on
    # every prediction request.
    scores = model.predict(batch, verbose=0)[0]
    if len(scores) != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {len(scores)} scores but {len(CLASS_NAMES)} "
            "class names are defined"
        )
    return {name: float(score) for name, score in zip(CLASS_NAMES, scores)}
|
|
|
|
|
|
def _select_example(path):
    """Button callback: make the example image at *path* the current selection."""
    st.session_state.selected_image = path


def _select_upload():
    """File-uploader callback: keep the selection in sync with the uploader.

    A newly uploaded file always becomes the selection. Clearing the uploader
    only clears the selection when the selection was the uploaded file (i.e.
    not an example path), so a chosen example survives removing an earlier
    upload.
    """
    uploaded = st.session_state.get('uploaded_file')
    current = st.session_state.get('selected_image')
    if uploaded is not None or not isinstance(current, str):
        st.session_state.selected_image = uploaded


def _open_selected_image(selection):
    """Open *selection* (a file path or uploaded file object) as an RGB image.

    Returns None when nothing is selected. Also returns None when the image
    cannot be opened or decoded; an error is shown in the page in that case.
    """
    if selection is None:
        return None
    try:
        return Image.open(selection).convert('RGB')
    except OSError as err:
        # Covers a missing example file (FileNotFoundError), an upload that is
        # not a real image (PIL.UnidentifiedImageError) and truncated data.
        st.error(f"Could not open the selected image: {err}")
        return None


def _probability_chart(prediction):
    """Build a bar chart of the per-class scores in *prediction* (name -> score)."""
    labels = list(prediction.keys())
    scores = list(prediction.values())
    fig = go.Figure(data=[go.Bar(
        x=labels,
        y=scores,
        marker=dict(
            color=scores,
            colorscale='Viridis',
            colorbar=dict(title='Probability'),
        ),
    )])
    fig.update_layout(
        title='Prediction Probabilities',
        xaxis_title='Waste Type',
        yaxis_title='Probability',
        height=500,
        width=700,
    )
    return fig


def run():
    """Render the waste-classification page.

    The page lets the user pick one of the example images under
    ``./visualization`` or upload their own. It then shows the chosen image
    and, when "Start Prediction" is pressed, the resized model input, the top
    class with its confidence, and a bar chart of all class scores. The chosen
    image is kept in ``st.session_state.selected_image`` across reruns.
    """
    st.title('🔍 Waste Classification Prediction')

    example_images = ['cig_package.jpg', 'stella.jpg', 'water_bottle.jpg', 'textile_shoes.jpg', 'organic_eggs.jpg', 'men_metalic_pose.jpg', 'normal_men.jpg', 'uno.jpg']
    example_path = './visualization'

    st.subheader("Choose an example image or upload your own:")

    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    # Selection changes go through widget callbacks. Those run before the
    # rerun, so the most recent user action (an example click or a new
    # upload) always wins. Previously a lingering upload overrode every
    # example click.
    cols = st.columns(4)
    for i, img_name in enumerate(example_images):
        with cols[i % 4]:
            img_path = os.path.join(example_path, img_name)
            st.image(img_path, width=100, caption=f'Example {i+1}')
            st.button(
                f"Example {i+1}",
                key=f"example_{i}",
                on_click=_select_example,
                args=(img_path,),
            )

    st.file_uploader(
        "Or upload your own image",
        type=["jpg", "jpeg", "png"],
        key='uploaded_file',
        on_change=_select_upload,
    )

    image = _open_selected_image(st.session_state.selected_image)
    if image is None:
        return

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Selected Image")
        # NOTE(review): use_column_width is deprecated in recent Streamlit
        # releases (use_container_width replaces it); kept for compatibility
        # with older versions. Confirm the target Streamlit version.
        st.image(image, caption='Selected Image', use_column_width=True)

    if not st.button("Start Prediction"):
        return

    progress_bar = st.progress(0)
    status_text = st.empty()

    status_text.text('Preprocessing image...')
    resized_image, _ = preprocess_image(image)
    progress_bar.progress(33)
    time.sleep(0.5)  # deliberate pause so each progress step is visible

    with col2:
        st.subheader("Resized Image (299x299 for Model)")
        st.image(resized_image, caption='Resized Image for Prediction (299x299)', use_column_width=True)

    status_text.text('Making prediction...')
    # predict() does its own preprocessing; the resize above is for display.
    prediction = predict(image)
    progress_bar.progress(66)
    time.sleep(0.5)

    status_text.text('Analyzing results...')
    predicted_class = max(prediction, key=prediction.get)
    confidence = prediction[predicted_class]
    progress_bar.progress(100)
    time.sleep(0.5)

    status_text.empty()
    progress_bar.empty()

    st.subheader("Prediction Results")
    st.write(f"Predicted waste type: **{predicted_class}**")
    st.write(f"Confidence: {confidence:.2%}")
    st.plotly_chart(_probability_chart(prediction))
|
|
|
|
|
|
# Entry point: render the app when this file is executed as a script rather
# than imported as a module.
if __name__ == "__main__":
    run()
|
|
|
|