import html
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.preprocessing import StandardScaler, LabelEncoder
| |
|
# Page-level configuration. Kept ahead of every other Streamlit call in this
# script (Streamlit expects set_page_config to run first -- see its docs).
st.set_page_config(layout="centered")

# Inject the global stylesheet: background image with a translucent white
# overlay, plus text/caption/selectbox styling.  The opening <style> tag is
# kept at column 0 of the string so Markdown treats the whole span as a raw
# HTML block instead of an indented code block.
st.markdown("""
<style>
    .stApp {
        background-image: url("https://as1.ftcdn.net/jpg/01/82/21/76/1000_F_182217694_DZi3Ytqsb0RpWQb9dwC7NLFwkwqgnh0r.jpg");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        height: auto; /* Allows the page to expand for scrolling */
        overflow: auto; /* Enables scrolling if the page content overflows */
        /* position: relative; */
    }

    /* Adjust opacity of overlay to make content more visible */
    .stApp::before {
        content: "";
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: rgba(255, 255, 255, 0.8); /* Slightly higher opacity */
        z-index: -1;
    }

    /* Ensure content appears above the overlay */
    .stApp > * {
        position: relative;
        z-index: 2;
    }

    /* Ensure the dataframe is visible */
    .dataframe {
        background-color: rgba(255, 255, 255, 0.9) !important;
        z-index: 3;
    }

    /* Style text elements for better visibility */
    h1, h3, span, div {
        text-shadow: 1px 1px 2px rgba(255, 255, 255, 0.2);
    }

    /* Custom CSS for select box heading */
    div.stSelectbox > label {
        color: #000000 !important; /* Change to your desired color */
        /* background-color: black !important; -- background color of the dropdown */
        font-size: 24px !important; /* Change font size */
        font-weight: bold !important; /* Make text bold */
    }

    /* Custom CSS for image caption */
    .custom-caption {
        color: #000000 !important; /* Change to your desired color */
        font-size: 24px !important; /* Optional: Change font size */
        text-align: center; /* Center-align the caption */
    }

    .stMainBlockContainer {
        background-color: white !important; /* Opaque background for the main content container */
    }

</style>
""", unsafe_allow_html=True)
| |
|
| |
|
| | |
def colored_title(text: str, color: str) -> None:
    """Render *text* as an ``<h1>`` heading in the given CSS *color*.

    Both arguments are interpolated into the HTML verbatim (no escaping) and
    rendered with ``unsafe_allow_html=True``, so pass trusted values only.
    """
    st.markdown(f"<h1 style='color: {color};'>{text}</h1>", unsafe_allow_html=True)
| |
|
def colored_subheader(text: str, color: str) -> None:
    """Render *text* as an ``<h3>`` sub-heading in the given CSS *color*.

    Both arguments are interpolated into the HTML verbatim (no escaping) and
    rendered with ``unsafe_allow_html=True``, so pass trusted values only.
    """
    st.markdown(f"<h3 style='color: {color};'>{text}</h3>", unsafe_allow_html=True)
| |
|
def colored_text(text: str, color: str) -> None:
    """Render *text* inside an inline ``<span>`` in the given CSS *color*.

    Both arguments are interpolated into the HTML verbatim (no escaping) and
    rendered with ``unsafe_allow_html=True``, so pass trusted values only.
    """
    st.markdown(f"<span style='color: {color};'>{text}</span>", unsafe_allow_html=True)
| |
|
class ClassNet(nn.Module):
    """Small convolutional classifier for 3-channel 30x30 images.

    Three unpadded convolutions with two 2x2 max-pools reduce a 30x30 input to
    a 32x4x4 feature map (30 -> 28 -> 24 -> 12 -> 8 -> 4), i.e. the 512
    features expected by the first fully connected layer.  Other input sizes
    will not match ``fc1``.

    Args:
        num_classes: Number of output logits.  Defaults to 43, the value the
            original network was hard-coded to.
    """

    def __init__(self, num_classes: int = 43):
        super().__init__()

        # Convolutional feature extractor (no padding, stride 1).
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(16, 32, 5)
        self.maxpool2 = nn.MaxPool2d(2)

        # Classifier head; 512 = 32 channels * 4 * 4 spatial positions.
        self.fc1 = nn.Linear(512, 256)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape ``(batch, num_classes)``.

        Args:
            input: Float tensor of shape ``(batch, 3, 30, 30)``.
        """
        x = F.relu(self.conv1(input))
        x = F.relu(self.conv2(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv3(x))
        x = self.maxpool2(x)

        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        output = self.fc3(x)

        return output
| |
|
@st.cache_resource
def load_model():
    """Build a ``ClassNet`` and load its trained weights onto the CPU.

    Returns:
        The model in evaluation mode, or ``None`` when the weights file could
        not be read or does not match the architecture (an error message is
        shown in the UI in that case).

    Note:
        The function is wrapped in ``st.cache_resource``; a ``None`` result is
        presumably cached like any other return value, so a failed load may
        persist until the cache is cleared -- TODO confirm.
    """
    model = ClassNet()
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code if the checkpoint is untrusted.  Consider passing
        # weights_only=True once the minimum supported torch version allows it.
        state_dict = torch.load(
            'traffic_light_model_weights.pth',
            map_location=torch.device('cpu'),
        )
        model.load_state_dict(state_dict)
    except Exception as e:
        # UI boundary: report the problem instead of crashing the app.
        st.error(f"Error loading model: {e}")
        return None

    # Inference only: switch dropout layers to evaluation behaviour.
    model.eval()
    return model
| |
|
@st.cache_data
def load_data():
    """Load the test images, their labels and the per-class reference images.

    Reads ``traffic_lights/Test.csv`` (columns ``Path`` and ``ClassId``),
    loads every listed image as a 30x30 RGB array, and collects the reference
    image ``traffic_lights/Meta/<class_id>.png`` for each class id in 0..42
    that exists on disk.

    Returns:
        A tuple ``(test_images, labels, meta_images)`` where

        * ``test_images`` is a list of ``uint8`` arrays of shape (30, 30, 3),
        * ``labels`` is an array of class ids with ``labels[i]`` belonging to
          ``test_images[i]``,
        * ``meta_images`` maps class id -> PIL image.
    """
    data_root = 'traffic_lights'
    y_test = pd.read_csv(os.path.join(data_root, 'Test.csv'))

    # Rows without a usable path (e.g. blank cells, which are not strings) are
    # dropped.  The same mask is applied to the labels so that images and
    # labels stay index-aligned.
    has_path = np.array([isinstance(p, str) for p in y_test["Path"]], dtype=bool)
    rows = y_test[has_path]
    labels = rows["ClassId"].values

    test_images = []
    for rel_path in rows["Path"]:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(os.path.join(data_root, rel_path)) as image:
            # Force 3 channels so every array is (30, 30, 3) regardless of the
            # file's colour mode (palette, greyscale, alpha, ...).
            resized = image.convert("RGB").resize((30, 30))
            test_images.append(np.array(resized))

    meta_images = {}
    meta_folder = os.path.join(data_root, 'Meta')
    for class_id in range(43):
        meta_image_path = os.path.join(meta_folder, f"{class_id}.png")
        if os.path.exists(meta_image_path):
            with Image.open(meta_image_path) as meta:
                # copy() loads the pixels and detaches them from the file.
                meta_images[class_id] = meta.copy()

    return test_images, labels, meta_images
| |
|
def main():
    """Render the app: pick a test image, run the model, show the prediction.

    Flow: load the cached dataset, let the user choose an image by index,
    display it with its ground-truth class, and -- once "Predict" is pressed --
    run the cached model on it and show the predicted class together with the
    reference image for that class when one is available.
    """
    colored_title("Traffic Symbol Prediction", "black")

    test_images, labels, meta_images = load_data()

    # With no images the selectbox has nothing to return and indexing below
    # would fail, so stop early with a visible message.
    if len(test_images) == 0:
        st.warning("No test images could be loaded.")
        return

    colored_subheader("Select an Image for Prediction:", "black")
    selected_index = st.selectbox(
        "Select an image by index:",
        options=range(len(test_images)),
        index=0,
    )

    st.image(test_images[selected_index], width=150)

    # The label comes from the CSV file; escape it before embedding it in
    # markup that is rendered with unsafe_allow_html.
    true_label = html.escape(str(labels[selected_index]))
    st.markdown(
        f'<p class="custom-caption">Selected Test Image (Class: {true_label})</p>',
        unsafe_allow_html=True
    )

    if not st.button("Predict"):
        return

    model = load_model()
    if model is None:
        # load_model already reported the error in the UI.
        return

    # Scale pixels to [0, 1] and reorder HWC -> CHW with a leading batch axis.
    image = test_images[selected_index] / 255.0
    image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        output = model(image)
        predicted_class = torch.argmax(output, dim=1).item()

    colored_subheader("Prediction Results:", "green")
    colored_text(f"Predicted Class: {predicted_class}", "green")

    if predicted_class in meta_images:
        st.image(meta_images[predicted_class], width=150)
        st.markdown(
            f'<p class="custom-caption">Clear Image for Class: {predicted_class}</p>',
            unsafe_allow_html=True
        )
    else:
        st.warning(f"No clear image found for class {predicted_class} in the meta folder.")
| | |
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()