"""Streamlit app: pick a test traffic-sign image, classify it with ClassNet, show the result."""

import html
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.preprocessing import StandardScaler, LabelEncoder  # NOTE(review): unused in this file

# Locations and dataset constants (previously hard-coded inline).
MODEL_WEIGHTS_PATH = "traffic_light_model_weights.pth"
DATA_DIR = "traffic_lights"
NUM_CLASSES = 43
# (width, height) passed to PIL resize; ClassNet's fc1 input size (512) depends on it.
IMAGE_SIZE = (30, 30)

st.set_page_config(layout="centered")

# Custom CSS for background image and styling.
# NOTE(review): the original CSS payload passed to st.markdown was empty (its markup
# appears to have been stripped), so the call was a no-op. Put the CSS here to restore it.
CUSTOM_CSS = ""
if CUSTOM_CSS.strip():
    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)


# NOTE(review): the HTML tags below are a reconstruction; the original markup was lost.
def _colored_html(tag, text, color):
    """Return *text* wrapped in an HTML *tag* whose CSS colour is *color*; both are escaped."""
    safe_color = html.escape(str(color), quote=True)
    return f'<{tag} style="color: {safe_color};">{html.escape(str(text))}</{tag}>'


def _centered_caption(text):
    """Render *text* as a centred paragraph (used under displayed images)."""
    st.markdown(
        f'<p style="text-align: center;">{html.escape(str(text))}</p>',
        unsafe_allow_html=True,
    )


# Custom title styling functions
def colored_title(text, color):
    """Render *text* as a top-level heading in CSS colour *color*."""
    st.markdown(_colored_html("h1", text, color), unsafe_allow_html=True)


def colored_subheader(text, color):
    """Render *text* as a sub-heading in CSS colour *color*."""
    st.markdown(_colored_html("h3", text, color), unsafe_allow_html=True)


def colored_text(text, color):
    """Render *text* as a paragraph in CSS colour *color*."""
    st.markdown(_colored_html("p", text, color), unsafe_allow_html=True)


class ClassNet(nn.Module):
    """Small CNN mapping a (N, 3, 30, 30) batch to (N, num_classes) logits.

    Spatial sizes for a 30x30 input: conv1 (3x3) -> 28, conv2 (5x5) -> 24,
    maxpool -> 12, conv3 (5x5) -> 8, maxpool -> 4; 32 channels * 4 * 4 = 512,
    which is fc1's input size.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(16, 32, 5)
        self.maxpool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(512, 256)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, input):
        """Return unnormalised class logits for an image batch."""
        x = F.relu(self.conv1(input))
        x = F.relu(self.conv2(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv3(x))
        x = self.maxpool2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return self.fc3(x)


@st.cache_resource
def _load_model_cached(weights_path):
    """Build ClassNet, load its weights onto CPU, and switch it to eval mode.

    Errors propagate on purpose: Streamlit does not cache a call that raises,
    so a failed load is retried on the next run instead of being cached as None.
    """
    model = ClassNet()
    # NOTE(review): torch.load unpickles the file, so only load trusted weights.
    # On PyTorch >= 1.13, consider weights_only=True.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model(weights_path=MODEL_WEIGHTS_PATH):
    """Return the cached eval-mode ClassNet, or None after showing an error if loading fails."""
    try:
        return _load_model_cached(weights_path)
    except Exception as e:  # UI boundary: report the failure and let the caller degrade
        st.error(f"Error loading model: {str(e)}")
        return None


@st.cache_data
def load_data(data_dir=DATA_DIR, num_classes=NUM_CLASSES):
    """Load the test images, their labels, and the per-class reference ("meta") images.

    Returns:
        test_images: list of uint8 arrays of shape (30, 30, 3), one per row of
            Test.csv whose Path is a string.
        labels: np.ndarray of ClassId values, aligned index-for-index with test_images.
        meta_images: dict mapping class id -> PIL image, for each <data_dir>/Meta/<id>.png
            that exists.
    """
    test_df = pd.read_csv(os.path.join(data_dir, "Test.csv"))

    test_images = []
    kept_labels = []
    for rel_path, label in zip(test_df["Path"].values, test_df["ClassId"].values):
        # Skip rows with a missing or non-string path, and drop their label too,
        # so labels stay aligned with test_images.
        if not isinstance(rel_path, str):
            continue
        with Image.open(os.path.join(data_dir, rel_path)) as img:
            # Force 3 channels to match ClassNet.conv1 (in_channels=3).
            test_images.append(np.array(img.convert("RGB").resize(IMAGE_SIZE)))
        kept_labels.append(label)

    # Reference images are expected to be named <class_id>.png.
    meta_dir = os.path.join(data_dir, "Meta")
    meta_images = {}
    for class_id in range(num_classes):
        meta_image_path = os.path.join(meta_dir, f"{class_id}.png")
        if os.path.exists(meta_image_path):
            with Image.open(meta_image_path) as img:
                # copy() loads the pixel data so the file handle can be closed.
                meta_images[class_id] = img.copy()

    return test_images, np.asarray(kept_labels), meta_images


def _to_model_input(image):
    """Convert an HxWx3 uint8 array to a (1, 3, H, W) float32 tensor scaled to [0, 1]."""
    scaled = image / 255.0
    return torch.tensor(scaled, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)


def main():
    """Render the app: image picker, the selected image, and the prediction on demand."""
    colored_title("Traffic Symbol Prediction", "black")

    test_images, labels, meta_images = load_data()
    if not test_images:
        # Without this guard, selectbox gets no options, returns None, and indexing crashes.
        st.warning("No test images found.")
        return

    colored_subheader("Select an Image for Prediction:", "black")
    selected_index = st.selectbox(
        "Select an image by index:", options=range(len(test_images)), index=0
    )

    st.image(test_images[selected_index], width=150)
    _centered_caption(f"Selected Test Image (Class: {labels[selected_index]})")

    if not st.button("Predict"):
        return

    model = load_model()
    if model is None:
        return  # load_model has already shown the error

    with torch.no_grad():
        output = model(_to_model_input(test_images[selected_index]))
    predicted_class = torch.argmax(output, dim=1).item()

    colored_subheader("Prediction Results:", "green")
    colored_text(f"Predicted Class: {predicted_class}", "green")

    meta_image = meta_images.get(predicted_class)
    if meta_image is None:
        st.warning(f"No clear image found for class {predicted_class} in the meta folder.")
    else:
        st.image(meta_image, width=150)
        _centered_caption(f"Clear Image for Class: {predicted_class}")


if __name__ == "__main__":
    main()