File size: 6,848 Bytes
100cb22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aadcc42
100cb22
 
 
0010ba2
 
 
 
91241f5
0010ba2
 
 
 
 
 
 
 
 
 
 
 
100cb22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import streamlit as st
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Kept as the first Streamlit call: Streamlit has historically required
# set_page_config to run before any other st.* command renders output.
st.set_page_config(layout="centered")

# Add custom CSS for the background image, a translucent overlay and text styling.
# Note: inside <style> only /* ... */ comments are valid CSS.
st.markdown("""
    <style>
        .stApp {
            background-image: url("https://as1.ftcdn.net/jpg/01/82/21/76/1000_F_182217694_DZi3Ytqsb0RpWQb9dwC7NLFwkwqgnh0r.jpg");
            background-size: cover;
            background-position: center;
            background-repeat: no-repeat;
            height: auto;  /* Allows the page to expand for scrolling */
            overflow: auto;  /* Enables scrolling if the page content overflows */
        }
        /* Translucent white overlay so content stays readable over the image */
        .stApp::before {
            content: "";
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background-color: rgba(255, 255, 255, 0.8);
            z-index: -1;
        }
        /* Ensure content appears above the overlay */
        .stApp > * {
            position: relative;
            z-index: 2;
        }
        /* Ensure the dataframe is visible */
        .dataframe {
            background-color: rgba(255, 255, 255, 0.9) !important;
            z-index: 3;
        }
        /* Style text elements for better visibility */
        h1, h3, span, div {
            text-shadow: 1px 1px 2px rgba(255, 255, 255, 0.2);
        }

        /* Select box heading: black, large, bold */
        div.stSelectbox > label {
            color: #000000 !important;
            font-size: 24px !important;
            font-weight: bold !important;
        }
        /* Image caption used via <p class="custom-caption"> */
        .custom-caption {
            color: #000000 !important;
            font-size: 24px !important;
            text-align: center;
        }

        /* Solid white panel behind the main content column */
        .stMainBlockContainer {
            background-color: white !important;
        }

    </style>
""", unsafe_allow_html=True)


# Helpers that render Streamlit markdown with an inline CSS colour.
def colored_title(text, color):
    """Render *text* as an <h1> heading coloured with the CSS value *color*.

    *text* is inserted into the HTML unescaped (unsafe_allow_html=True), so
    only pass trusted strings.
    """
    heading_html = f"<h1 style='color: {color};'>{text}</h1>"
    st.markdown(heading_html, unsafe_allow_html=True)

def colored_subheader(text, color):
    """Render *text* as an <h3> sub-heading coloured with the CSS value *color*.

    *text* is inserted into the HTML unescaped, so only pass trusted strings.
    """
    subheading_html = f"<h3 style='color: {color};'>{text}</h3>"
    st.markdown(subheading_html, unsafe_allow_html=True)

def colored_text(text, color):
    """Render *text* as an inline <span> coloured with the CSS value *color*.

    *text* is inserted into the HTML unescaped, so only pass trusted strings.
    """
    span_html = f"<span style='color: {color};'>{text}</span>"
    st.markdown(span_html, unsafe_allow_html=True)

class ClassNet(nn.Module):
    """Small CNN that maps 3-channel images to 43 class logits.

    The layer attribute names (conv1..conv3, fc1..fc3, ...) are the keys of
    the saved state_dict, so they must not be renamed. With 30x30 inputs
    (the size load_data resizes to) the feature map entering fc1 holds
    32 x 4 x 4 = 512 values:
        30 -conv1(k3)-> 28 -conv2(k5)-> 24 -pool-> 12 -conv3(k5)-> 8 -pool-> 4
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Layers are created in a fixed order so that
        # seeded weight initialisation stays reproducible.
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(16, 32, 5)
        self.maxpool2 = nn.MaxPool2d(2)

        # Fully connected classifier head with dropout between layers.
        self.fc1 = nn.Linear(512, 256)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 43)

    def forward(self, input):
        """Return raw (unnormalised) logits of shape (batch, 43)."""
        # conv -> relu twice, pool; then conv -> relu, pool.
        features = self.maxpool1(F.relu(self.conv2(F.relu(self.conv1(input)))))
        features = self.maxpool2(F.relu(self.conv3(features)))

        # Flatten everything except the batch axis, then the dense head.
        hidden = torch.flatten(features, 1)
        for linear, drop in ((self.fc1, self.dropout1), (self.fc2, self.dropout2)):
            hidden = drop(F.relu(linear(hidden)))
        return self.fc3(hidden)

@st.cache_resource
def _load_model_cached(weights_path):
    """Build a ClassNet, load its weights onto the CPU and return it in eval mode.

    This helper raises on failure instead of returning a sentinel. Streamlit
    does not cache a call that raises, so a missing or corrupt weights file
    is retried on the next call. Previously the failure was remembered as
    ``None`` for the lifetime of the server process.
    """
    model = ClassNet()
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    # Consider weights_only=True once the minimum torch version allows it.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    return model


def load_model(weights_path='traffic_light_model_weights.pth'):
    """Return the cached inference model, or None if it could not be loaded.

    Args:
        weights_path: Path of the saved ClassNet state_dict.

    Returns:
        The model in eval mode, or None after showing the error in the UI.
    """
    try:
        return _load_model_cached(weights_path)
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.error(f"Error loading model: {str(e)}")
        return None

@st.cache_data
def load_data(data_dir="traffic_lights", num_classes=43, image_size=(30, 30)):
    """Load the test images, their labels and the per-class reference images.

    Reads ``<data_dir>/Test.csv`` and uses its ``Path`` and ``ClassId``
    columns. A row whose ``Path`` is not a string (e.g. an empty cell read
    as NaN) is skipped together with its label, so ``test_images[i]`` and
    ``labels[i]`` always describe the same CSV row.

    Args:
        data_dir: Directory containing ``Test.csv``, the image files it
            refers to, and a ``Meta`` folder of reference images.
        num_classes: Class ids 0..num_classes-1 are looked up in ``Meta``.
        image_size: (width, height) each test image is resized to; the
            default 30x30 matches ClassNet's 512-wide fc1 input.

    Returns:
        (test_images, labels, meta_images):
        - a list of HxWx3 uint8 arrays;
        - a numpy array of the matching ClassId values;
        - a dict mapping class id to a PIL image, for every ``Meta/<id>.png``
          that exists.
    """
    frame = pd.read_csv(os.path.join(data_dir, "Test.csv"))
    paths = frame["Path"].values
    class_ids = frame["ClassId"].values

    # Drop rows with no usable path from BOTH images and labels. Previously
    # only the images were skipped, which shifted every later label.
    has_path = np.array([isinstance(p, str) for p in paths], dtype=bool)

    test_images = []
    for rel_path in paths[has_path]:
        # The context manager closes the file; np.array copies the pixels out first.
        with Image.open(os.path.join(data_dir, rel_path)) as img:
            # convert("RGB") leaves RGB pixels unchanged and guarantees the
            # 3 channels ClassNet.conv1 expects for RGBA/greyscale/palette files.
            test_images.append(np.array(img.convert("RGB").resize(image_size)))
    labels = class_ids[has_path]

    # Reference ("meta") images, expected to be named <class_id>.png.
    meta_images = {}
    meta_folder = os.path.join(data_dir, "Meta")
    for class_id in range(num_classes):
        meta_image_path = os.path.join(meta_folder, f"{class_id}.png")
        if os.path.exists(meta_image_path):
            image = Image.open(meta_image_path)
            # Read the pixels now rather than lazily. For single-frame images,
            # Pillow then closes the file it opened.
            image.load()
            meta_images[class_id] = image

    return test_images, labels, meta_images

def main():
    colored_title("Traffic Symbol Prediction", "black")

    # Load data
    test_images, labels, meta_images = load_data()

    # Display test images for selection
    colored_subheader("Select an Image for Prediction:", "black")
    selected_index = st.selectbox("Select an image by index:", options=range(len(test_images)), index=0)

    # Display the selected test image
    st.image(test_images[selected_index], width=150)

    st.markdown(
        f'<p class="custom-caption">Selected Test Image (Class: {labels[selected_index]})</p>',
        unsafe_allow_html=True
    )

    # Predict button
    if st.button("Predict"):
        model = load_model()
        if model is not None:
            # Preprocess the selected image
            image = test_images[selected_index] / 255.0  # Normalize
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)  # Convert to tensor

            # Make prediction
            with torch.no_grad():
                output = model(image)
                predicted_class = torch.argmax(output, dim=1).item()

            # Display prediction result
            colored_subheader("Prediction Results:", "green")
            colored_text(f"Predicted Class: {predicted_class}", "green")

            # Display the corresponding meta image
            if predicted_class in meta_images:
                st.image(meta_images[predicted_class], width=150)
                st.markdown(
                    f'<p class="custom-caption">Clear Image for Class: {predicted_class}</p>',
                    unsafe_allow_html=True
                )
            else:
                st.warning(f"No clear image found for class {predicted_class} in the meta folder.")  
            
# Entry point: build the page when this file is executed as the main script
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()