# Author: amariayudha
# Commit: 1823aef ("Upload prediction.py", verified)
import streamlit as st
import numpy as np
from PIL import Image
import plotly.graph_objects as go
import tensorflow as tf
import time
import os
# Trained transfer-learning classifier; st.cache_resource reuses the returned
# model across reruns instead of reading the .h5 file every time.
@st.cache_resource
def load_model():
    """Return the Keras model stored in ``transfer_learning_model.h5``.

    The path is resolved relative to the current working directory.
    """
    model_file = 'transfer_learning_model.h5'
    keras_model = tf.keras.models.load_model(model_file)
    return keras_model
def preprocess_image(image, target_size=(299, 299)):
    """Resize *image* and turn it into a normalised single-image batch.

    Args:
        image: a PIL image in any mode. It is converted to RGB first so that
            grayscale or RGBA inputs still yield exactly three channels.
        target_size: ``(width, height)`` passed to ``Image.resize``; defaults
            to the 299x299 size this app feeds to the model.

    Returns:
        ``(resized_image, batch)``: the resized RGB PIL image (useful for
        display) and a float32 array of shape ``(1, height, width, 3)`` with
        values scaled from 0-255 down to [0, 1].
    """
    resized = image.convert('RGB').resize(target_size)
    # float32 is Keras' default float type and halves memory versus float64.
    batch = np.asarray(resized, dtype=np.float32) / 255.0
    # Add the leading batch axis the model's predict() expects.
    return resized, batch[np.newaxis, ...]
# Labels for the model's output units, in output order. The list is
# alphabetical, which presumably mirrors the label order used at training
# time — confirm against the training code.
CLASS_NAMES = (
    'Cardboard', 'Food Organics', 'Glass', 'Metal', 'Miscellaneous Trash',
    'Paper', 'Plastic', 'Textile Trash', 'Vegetation',
)


def predict(image):
    """Classify *image* and return a ``{class name: probability}`` dict.

    Args:
        image: a PIL image; it is resized/normalised by ``preprocess_image``.

    Returns:
        A dict mapping every entry of ``CLASS_NAMES`` to the model's score for
        that class, as a Python float.

    Raises:
        ValueError: if the model's output width does not match
            ``len(CLASS_NAMES)`` (previously this truncated silently or
            raised an opaque IndexError).
    """
    model = load_model()
    _, processed_image = preprocess_image(image)
    # verbose=0 keeps Keras' per-call progress bar out of the server log.
    scores = model.predict(processed_image, verbose=0)[0]
    if len(scores) != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {len(scores)} scores but {len(CLASS_NAMES)} "
            f"class names are defined"
        )
    return {name: float(score) for name, score in zip(CLASS_NAMES, scores)}
def _render_example_gallery(example_images, example_path):
    """Show example thumbnails in a 4-column grid; a click selects that path.

    The chosen file path is stored in ``st.session_state.selected_image``.
    """
    cols = st.columns(4)
    for i, img_name in enumerate(example_images):
        with cols[i % 4]:
            img_path = os.path.join(example_path, img_name)
            # Skip missing files instead of letting st.image crash the page
            # (the path is relative to the working directory).
            if not os.path.isfile(img_path):
                st.warning(f"Example image not found: {img_path}")
                continue
            # Preview image, rendered above its selection button
            st.image(img_path, width=100, caption=f'Example {i+1}')
            if st.button(f"Example {i+1}", key=f"example_{i}"):
                st.session_state.selected_image = img_path


def _sync_uploaded_file(uploaded_file):
    """Select an upload only when it changes, so example clicks still win.

    Streamlit keeps returning the same file from the uploader on every rerun;
    unconditionally selecting it would override any later example click.
    """
    upload_id = None if uploaded_file is None else (uploaded_file.name, uploaded_file.size)
    if upload_id == st.session_state.get('last_upload_id'):
        return
    st.session_state.last_upload_id = upload_id
    if uploaded_file is not None:
        st.session_state.selected_image = uploaded_file
    elif not isinstance(st.session_state.selected_image, str):
        # The upload being shown was removed from the uploader; drop it.
        st.session_state.selected_image = None


def _open_selected_image(selection):
    """Open an example path or uploaded file as an RGB PIL image.

    Returns None when nothing is selected, or (after showing an error) when
    the file cannot be read as an image.
    """
    if selection is None:
        return None
    try:
        # Image.open accepts both a filesystem path and a file-like upload.
        return Image.open(selection).convert('RGB')
    except OSError as err:  # includes FileNotFoundError and PIL's UnidentifiedImageError
        st.error(f"Could not open the selected image: {err}")
        return None


def _probability_chart(prediction):
    """Build a vertical Plotly bar chart of class name -> probability."""
    labels = list(prediction.keys())
    probabilities = list(prediction.values())
    fig = go.Figure(data=[go.Bar(
        x=labels,
        y=probabilities,
        marker=dict(
            color=probabilities,
            colorscale='Viridis',
            colorbar=dict(title='Probability')
        )
    )])
    fig.update_layout(
        title='Prediction Probabilities',
        xaxis_title='Waste Type',
        yaxis_title='Probability',
        height=500,
        width=700
    )
    return fig


def run():
    """Render the prediction page: choose an image, then classify it on demand."""
    st.title('🔍 Waste Classification Prediction')
    example_images = ['cig_package.jpg', 'stella.jpg', 'water_bottle.jpg', 'textile_shoes.jpg', 'organic_eggs.jpg', 'men_metalic_pose.jpg', 'normal_men.jpg', 'uno.jpg']
    example_path = './visualization'
    st.subheader("Choose an example image or upload your own:")

    # Holds either an example file path (str) or an uploaded file object.
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    _render_example_gallery(example_images, example_path)
    uploaded_file = st.file_uploader("Or upload your own image", type=["jpg", "jpeg", "png"])
    _sync_uploaded_file(uploaded_file)

    image = _open_selected_image(st.session_state.selected_image)
    if image is None:
        return

    # Original image on the left; the resized model input goes on the right.
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Selected Image")
        st.image(image, caption='Selected Image', use_column_width=True)

    if not st.button("Start Prediction"):
        return

    progress_bar = st.progress(0)
    status_text = st.empty()

    status_text.text('Preprocessing image...')
    resized_image, _ = preprocess_image(image)
    progress_bar.progress(33)
    time.sleep(0.5)  # Deliberate pause so each step is visible to the user

    with col2:
        st.subheader("Resized Image (299x299 for Model)")
        st.image(resized_image, caption='Resized Image for Prediction (299x299)', use_column_width=True)

    status_text.text('Making prediction...')
    prediction = predict(image)
    progress_bar.progress(66)
    time.sleep(0.5)  # Deliberate pause so each step is visible to the user

    status_text.text('Analyzing results...')
    predicted_class = max(prediction, key=prediction.get)
    confidence = prediction[predicted_class]
    progress_bar.progress(100)
    time.sleep(0.5)  # Deliberate pause so each step is visible to the user

    status_text.empty()
    progress_bar.empty()

    st.subheader("Prediction Results")
    st.write(f"Predicted waste type: **{predicted_class}**")
    st.write(f"Confidence: {confidence:.2%}")
    st.plotly_chart(_probability_chart(prediction))
# Entry point when this module is executed as the main script.
if __name__ == '__main__':
    run()