YAMITEK commited on
Commit
100cb22
·
verified ·
1 Parent(s): 0a8cc79

Upload 5 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ traffic_sign_document.pdf filter=lfs diff=lfs merge=lfs -text
Traffic_light_prediction_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ import os
8
+ from PIL import Image
9
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
10
+
11
+ st.set_page_config(layout="centered")
12
+
13
+ # Add custom CSS for background image and styling
14
+ # Add custom CSS for background image and styling
15
+ st.markdown("""
16
+ <style>
17
+ .stApp {
18
+ background-image: url("https://as1.ftcdn.net/jpg/01/82/21/76/1000_F_182217694_DZi3Ytqsb0RpWQb9dwC7NLFwkwqgnh0r.jpg");
19
+ background-size: cover;
20
+ background-position: center;
21
+ background-repeat: no-repeat;
22
+ height: auto; /* Allows the page to expand for scrolling */
23
+ overflow: auto; /* Enables scrolling if the page content overflows */
24
+ # position : relative
25
+ }
26
+
27
+ /* Adjust opacity of overlay to make content more visible */
28
+ .stApp::before {
29
+ content: "";
30
+ position: absolute;
31
+ top: 0;
32
+ left: 0;
33
+ width: 100%;
34
+ height: 100%;
35
+ background-color: rgba(255, 255, 255, 0.8); /* Slightly higher opacity */
36
+ z-index: -1;
37
+ }
38
+
39
+ /* Ensure content appears above the overlay */
40
+ .stApp > * {
41
+ position: relative;
42
+ z-index: 2;
43
+ }
44
+
45
+ /* Ensure the dataframe is visible */
46
+ .dataframe {
47
+ background-color: rgba(255, 255, 255, 0.9) !important;
48
+ z-index: 3;
49
+ }
50
+
51
+ /* Style text elements for better visibility */
52
+ h1, h3, span, div {
53
+ text-shadow: 1px 1px 2px rgba(255, 255, 255, 0.2);
54
+ }
55
+
56
+ /* Custom CSS for select box heading */
57
+ div.stSelectbox > label {
58
+ color: #000000 !important; /* Change to your desired color */
59
+ # background-color: black !important; /* Background color of the dropdown */
60
+ font-size: 24px !important; /* Change font size */
61
+ font-weight: bold !important; /* Make text bold */
62
+ }
63
+
64
+ /* Custom CSS for image caption */
65
+ .custom-caption {
66
+ color: #000000 !important; /* Change to your desired color */
67
+ font-size: 24px !important; /* Optional: Change font size */
68
+ text-align: center; /* Center-align the caption */
69
+ }
70
+
71
+ .stMainBlockContainer {
72
+ background-color: white !important; /* Background color of the dropdown */
73
+ }
74
+
75
+ </style>
76
+ """, unsafe_allow_html=True)
77
+
78
+
79
+ # Custom title styling functions
80
+ def colored_title(text, color):
81
+ st.markdown(f"<h1 style='color: {color};'>{text}</h1>", unsafe_allow_html=True)
82
+
83
+ def colored_subheader(text, color):
84
+ st.markdown(f"<h3 style='color: {color};'>{text}</h3>", unsafe_allow_html=True)
85
+
86
+ def colored_text(text, color):
87
+ st.markdown(f"<span style='color: {color};'>{text}</span>", unsafe_allow_html=True)
88
+
89
+ class ClassNet(nn.Module):
90
+
91
+ def __init__(self):
92
+ super(ClassNet, self).__init__()
93
+
94
+ self.conv1 = nn.Conv2d(3,6,3)
95
+ self.conv2 = nn.Conv2d(6,16,5)
96
+ self.maxpool1 = nn.MaxPool2d(2)
97
+ self.conv3 = nn.Conv2d(16,32,5)
98
+ self.maxpool2 = nn.MaxPool2d(2)
99
+
100
+ self.fc1 = nn.Linear(512,256)
101
+ self.dropout1 = nn.Dropout(0.5)
102
+ self.fc2 = nn.Linear(256,128)
103
+ self.dropout2 = nn.Dropout(0.5)
104
+ self.fc3 = nn.Linear(128,43)
105
+ def forward(self,input):
106
+
107
+ x = F.relu(self.conv1(input))
108
+ x = F.relu(self.conv2(x))
109
+ x = self.maxpool1(x)
110
+ x = F.relu(self.conv3(x))
111
+ x = self.maxpool2(x)
112
+
113
+ x = torch.flatten(x,1)
114
+ x = F.relu(self.fc1(x))
115
+ x = self.dropout1(x)
116
+ x = F.relu(self.fc2(x))
117
+ x = self.dropout2(x)
118
+ output = self.fc3(x)
119
+
120
+ return output
121
+
122
+ @st.cache_resource
123
+ def load_model():
124
+
125
+ model = ClassNet()
126
+ try:
127
+ state_dict = torch.load('traffic_light_model_weights.pth', map_location=torch.device('cpu'))
128
+ model.load_state_dict(state_dict)
129
+ model.eval()
130
+ return model
131
+ except Exception as e:
132
+ st.error(f"Error loading model: {str(e)}")
133
+ return None
134
+
135
+ @st.cache_data
136
+ def load_data():
137
+
138
+ y_test = pd.read_csv('traffic_lights/Test.csv')
139
+
140
+ imgs = y_test["Path"].values
141
+ labels = y_test["ClassId"].values
142
+
143
+ test_images = []
144
+ for img in imgs:
145
+ if isinstance(img,str):
146
+ image = Image.open('traffic_lights/'+img)
147
+ image = image.resize([30, 30])
148
+ test_images.append(np.array(image))
149
+
150
+ # Load meta images
151
+ meta_images = {}
152
+ meta_folder = 'traffic_lights/Meta' # Replace with the path to your meta folder
153
+ for class_id in range(43):
154
+ meta_image_path = os.path.join(meta_folder, f"{class_id}.png") # Assuming meta images are named as 0.png, 1.png, etc.
155
+ if os.path.exists(meta_image_path):
156
+ meta_images[class_id] = Image.open(meta_image_path)
157
+
158
+ return test_images, labels, meta_images
159
+
160
+ def main():
161
+ colored_title("Traffic Symbol Prediction", "black")
162
+
163
+ # Load data
164
+ test_images, labels, meta_images = load_data()
165
+
166
+ # Display test images for selection
167
+ colored_subheader("Select an Image for Prediction:", "black")
168
+ selected_index = st.selectbox("Select an image by index:", options=range(len(test_images)), index=0)
169
+
170
+ # Display the selected test image
171
+ st.image(test_images[selected_index], width=150)
172
+
173
+ st.markdown(
174
+ f'<p class="custom-caption">Selected Test Image (Class: {labels[selected_index]})</p>',
175
+ unsafe_allow_html=True
176
+ )
177
+
178
+ # Predict button
179
+ if st.button("Predict"):
180
+ model = load_model()
181
+ if model is not None:
182
+ # Preprocess the selected image
183
+ image = test_images[selected_index] / 255.0 # Normalize
184
+ image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) # Convert to tensor
185
+
186
+ # Make prediction
187
+ with torch.no_grad():
188
+ output = model(image)
189
+ predicted_class = torch.argmax(output, dim=1).item()
190
+
191
+ # Display prediction result
192
+ colored_subheader("Prediction Results:", "green")
193
+ colored_text(f"Predicted Class: {predicted_class}", "green")
194
+
195
+ # Display the corresponding meta image
196
+ if predicted_class in meta_images:
197
+ st.image(meta_images[predicted_class], width=150)
198
+ st.markdown(
199
+ f'<p class="custom-caption">Clear Image for Class: {predicted_class}</p>',
200
+ unsafe_allow_html=True
201
+ )
202
+ else:
203
+ st.warning(f"No clear image found for class {predicted_class} in the meta folder.")
204
+
205
+ if __name__ == "__main__":
206
+ main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit==1.25.0
2
+ pandas==1.5.3
3
+ numpy==1.24.3
4
+ scikit-learn==1.2.2
5
+ torch==2.0.1
6
+ torchvision==0.15.2
7
+ torchaudio==2.0.2
traffic_light_model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20755445924104845831325f77bef167d20929c9e5a1521b9bcf2955e0ef2304
3
+ size 745522
traffic_sign_document.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe36f16740a86d72283919b6e698c4ffd8b69346b64008039a3474421a21fcf7
3
+ size 109756