File size: 5,182 Bytes
1f80ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1823aef
1f80ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
import numpy as np
from PIL import Image
import plotly.graph_objects as go
import tensorflow as tf
import time
import os

# Load your trained model
@st.cache_resource
def load_model():
    return tf.keras.models.load_model('transfer_learning_model.h5')

def preprocess_image(image):
    img = image.resize((299, 299))
    img_array = np.array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    return img, img_array

def predict(image):
    model = load_model()
    _, processed_image = preprocess_image(image)
    prediction = model.predict(processed_image)
    class_names = ['Cardboard', 'Food Organics', 'Glass', 'Metal', 'Miscellaneous Trash', 'Paper', 'Plastic', 'Textile Trash', 'Vegetation']
    return {class_names[i]: float(prediction[0][i]) for i in range(len(class_names))}

def run():
    st.title('🔍 Waste Classification Prediction')

    # Example images
    example_images = ['cig_package.jpg', 'stella.jpg', 'water_bottle.jpg', 'textile_shoes.jpg', 'organic_eggs.jpg', 'men_metalic_pose.jpg', 'normal_men.jpg','uno.jpg']
    example_path = './visualization'  # Set the path 

    st.subheader("Choose an example image or upload your own:")

    # Initialize session state for the selected image
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    # Create columns for example images
    cols = st.columns(4)
    for i, img_name in enumerate(example_images):
        with cols[i % 4]:
            img_path = os.path.join(example_path, img_name)
            
            # Display the preview image under the button
            st.image(img_path, width=100, caption=f'Example {i+1}')
            
            # Create the button for each example
            if st.button(f"Example {i+1}", key=f"example_{i}"):
                st.session_state.selected_image = img_path

    uploaded_file = st.file_uploader("Or upload your own image", type=["jpg", "jpeg", "png"])

    # Use session state to store the selected or uploaded image
    if uploaded_file is not None:
        st.session_state.selected_image = uploaded_file

    image = None
    if st.session_state.selected_image is not None:
        if isinstance(st.session_state.selected_image, str):  # Example image case
            image = Image.open(st.session_state.selected_image).convert('RGB')
        else:  # Uploaded image case
            image = Image.open(st.session_state.selected_image).convert('RGB')

    if image:
        # Create two columns for images
        col1, col2 = st.columns(2)

        # Display original image in the left column
        with col1:
            st.subheader("Selected Image")
            st.image(image, caption='Selected Image', use_column_width=True)

        # Add a button to start prediction
        if st.button("Start Prediction"):
            # Progress and status indicators
            progress_bar = st.progress(0)
            status_text = st.empty()

            # Preprocess the image
            status_text.text('Preprocessing image...')
            resized_image, _ = preprocess_image(image)
            progress_bar.progress(33)
            time.sleep(0.5)  # Simulate processing time

            # Display resized image in the right column
            with col2:
                st.subheader("Resized Image (299x299 for Model)")
                st.image(resized_image, caption='Resized Image for Prediction (299x299)', use_column_width=True)

            # Make prediction
            status_text.text('Making prediction...')
            prediction = predict(image)
            progress_bar.progress(66)
            time.sleep(0.5)  # Simulate processing time

            # Analyze results
            status_text.text('Analyzing results...')
            predicted_class = max(prediction, key=prediction.get)
            confidence = prediction[predicted_class]
            progress_bar.progress(100)
            time.sleep(0.5)  # Simulate processing time

            # Clear the status text and progress bar
            status_text.empty()
            progress_bar.empty()

            # Display prediction results under the images
            st.subheader("Prediction Results")
            st.write(f"Predicted waste type: **{predicted_class}**")
            st.write(f"Confidence: {confidence:.2%}")

            # Display vertical bar chart of probabilities using Plotly
            fig = go.Figure(data=[go.Bar(
                x=list(prediction.keys()),
                y=list(prediction.values()),
                marker=dict(
                    color=list(prediction.values()),
                    colorscale='Viridis',
                    colorbar=dict(title='Probability')
                )
            )])
            fig.update_layout(
                title='Prediction Probabilities',
                xaxis_title='Waste Type',
                yaxis_title='Probability',
                height=500,
                width=700
            )
            st.plotly_chart(fig)

if __name__ == "__main__":
    run()