import html
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.preprocessing import StandardScaler, LabelEncoder
st.set_page_config(layout="centered")

# Custom CSS for the background image and general page styling.
# NOTE(review): the stylesheet payload is empty in this copy of the file (the
# <style> content appears to have been lost); restore it here or drop it.
_CUSTOM_CSS = """
"""
# Only inject when there is actual CSS, so no blank markdown element is added.
if _CUSTOM_CSS.strip():
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Custom title styling functions
def colored_title(text, color):
    """Render *text* as an ``<h1>`` heading in the CSS colour *color*.

    Both values are HTML-escaped because the markup is rendered with
    ``unsafe_allow_html=True``; *text* may be any object (it is str()-ed).
    """
    st.markdown(
        f'<h1 style="color: {html.escape(str(color))};">{html.escape(str(text))}</h1>',
        unsafe_allow_html=True,
    )
def colored_subheader(text, color):
    """Render *text* as an ``<h3>`` subheader in the CSS colour *color*.

    Both values are HTML-escaped because the markup is rendered with
    ``unsafe_allow_html=True``; *text* may be any object (it is str()-ed).
    """
    st.markdown(
        f'<h3 style="color: {html.escape(str(color))};">{html.escape(str(text))}</h3>',
        unsafe_allow_html=True,
    )
def colored_text(text, color):
    """Render *text* as a paragraph in the CSS colour *color*.

    The original ignored *color*; it is now applied via an inline style.
    Both values are HTML-escaped because the markup is rendered with
    ``unsafe_allow_html=True``.
    """
    st.markdown(
        f'<p style="color: {html.escape(str(color))};">{html.escape(str(text))}</p>',
        unsafe_allow_html=True,
    )
class ClassNet(nn.Module):
    """Small CNN classifier for 3-channel images sized 30x30.

    The layer attribute names (conv1, conv2, conv3, fc1, fc2, fc3, ...) must
    not change: they are the keys matched by ``load_state_dict`` when saved
    weights are loaded.

    Args:
        num_classes: number of output logits (defaults to 43).
    """

    def __init__(self, num_classes=43):
        super().__init__()
        # Spatial sizes below assume a 30x30 input.
        self.conv1 = nn.Conv2d(3, 6, 3)     # 30 -> 28
        self.conv2 = nn.Conv2d(6, 16, 5)    # 28 -> 24
        self.maxpool1 = nn.MaxPool2d(2)     # 24 -> 12
        self.conv3 = nn.Conv2d(16, 32, 5)   # 12 -> 8
        self.maxpool2 = nn.MaxPool2d(2)     # 8 -> 4
        # 32 channels * 4 * 4 = 512 flattened features; ties fc1 to the input size.
        self.fc1 = nn.Linear(512, 256)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, input):
        """Return unnormalised class logits of shape (N, num_classes) for an (N, 3, H, W) batch."""
        x = F.relu(self.conv1(input))
        x = F.relu(self.conv2(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv3(x))
        x = self.maxpool2(x)
        x = torch.flatten(x, 1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.fc3(x)
@st.cache_resource
def _load_model_cached():
    """Build ClassNet, load its CPU weights and switch it to eval mode.

    Raises on any failure instead of returning a sentinel, so Streamlit's
    resource cache only ever stores a successfully loaded model.
    """
    model = ClassNet()
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    # For a plain state_dict, weights_only=True (PyTorch >= 1.13) is safer.
    state_dict = torch.load('traffic_light_model_weights.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Return the cached eval-mode model, or None after showing an error.

    The error path is deliberately outside the cached function: caching the
    None result would make a load failure permanent until the cache is
    cleared, even after the weights file is fixed.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Error loading model: {str(e)}")
        return None
@st.cache_data
def load_data():
    """Load the test split and the per-class reference ("Meta") images.

    Reads ``traffic_lights/Test.csv`` (columns ``Path`` and ``ClassId``) and
    opens each listed image relative to ``traffic_lights/``, converting it to
    RGB and resizing it to 30x30.

    Returns:
        tuple: ``(test_images, labels, meta_images)`` where ``test_images`` is
        a list of (30, 30, 3) arrays, ``labels`` is an array aligned
        index-for-index with ``test_images``, and ``meta_images`` maps a class
        id in ``range(43)`` to a PIL image for every ``Meta/<id>.png`` found.
    """
    data_dir = 'traffic_lights'
    test_df = pd.read_csv(os.path.join(data_dir, 'Test.csv'))

    paths = test_df["Path"].values
    # Only rows with a string path can be loaded. The same mask is applied to
    # the labels so test_images[i] and labels[i] always describe one row.
    keep = np.array([isinstance(p, str) for p in paths], dtype=bool)
    labels = test_df["ClassId"].values[keep]

    test_images = []
    for rel_path in paths[keep]:
        with Image.open(os.path.join(data_dir, rel_path)) as image:
            # convert("RGB") guarantees the 3 channels the model expects.
            test_images.append(np.array(image.convert("RGB").resize((30, 30))))

    meta_images = {}
    meta_folder = os.path.join(data_dir, 'Meta')
    for class_id in range(43):
        # Reference images are expected to be named 0.png, 1.png, ...
        meta_image_path = os.path.join(meta_folder, f"{class_id}.png")
        if os.path.exists(meta_image_path):
            with Image.open(meta_image_path) as meta_image:
                # copy() loads the pixels so the file can be closed here.
                meta_images[class_id] = meta_image.copy()
    return test_images, labels, meta_images
def main():
colored_title("Traffic Symbol Prediction", "black")
# Load data
test_images, labels, meta_images = load_data()
# Display test images for selection
colored_subheader("Select an Image for Prediction:", "black")
selected_index = st.selectbox("Select an image by index:", options=range(len(test_images)), index=0)
# Display the selected test image
st.image(test_images[selected_index], width=150)
st.markdown(
f'Selected Test Image (Class: {labels[selected_index]})
',
unsafe_allow_html=True
)
# Predict button
if st.button("Predict"):
model = load_model()
if model is not None:
# Preprocess the selected image
image = test_images[selected_index] / 255.0 # Normalize
image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) # Convert to tensor
# Make prediction
with torch.no_grad():
output = model(image)
predicted_class = torch.argmax(output, dim=1).item()
# Display prediction result
colored_subheader("Prediction Results:", "green")
colored_text(f"Predicted Class: {predicted_class}", "green")
# Display the corresponding meta image
if predicted_class in meta_images:
st.image(meta_images[predicted_class], width=150)
st.markdown(
f'Clear Image for Class: {predicted_class}
',
unsafe_allow_html=True
)
else:
st.warning(f"No clear image found for class {predicted_class} in the meta folder.")
# Build the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()