# Captured from a "Spaces" page; the page status read: "Build error".
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
# NOTE(review): StandardScaler and LabelEncoder are not used by the code in this
# file as captured; kept in case something outside this view relies on them.
from sklearn.preprocessing import LabelEncoder, StandardScaler
# Must be the first Streamlit call in the script.
st.set_page_config(layout="centered")

# Custom CSS for the background image and general styling.
# CSS only understands /* ... */ comments; a line starting with "#" is an
# invalid declaration that the browser silently discards, so commented-out
# rules below use /* */ instead.
st.markdown("""
<style>
.stApp {
    background-image: url("https://as1.ftcdn.net/jpg/01/82/21/76/1000_F_182217694_DZi3Ytqsb0RpWQb9dwC7NLFwkwqgnh0r.jpg");
    background-size: cover;
    background-position: center;
    background-repeat: no-repeat;
    height: auto; /* Allows the page to expand for scrolling */
    overflow: auto; /* Enables scrolling if the page content overflows */
    /* position: relative; */
}
/* Translucent white overlay on top of the background image so content stays readable */
.stApp::before {
    content: "";
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    background-color: rgba(255, 255, 255, 0.8); /* Slightly higher opacity */
    z-index: -1;
}
/* Ensure content appears above the overlay */
.stApp > * {
    position: relative;
    z-index: 2;
}
/* Ensure the dataframe is visible */
.dataframe {
    background-color: rgba(255, 255, 255, 0.9) !important;
    z-index: 3;
}
/* Style text elements for better visibility */
h1, h3, span, div {
    text-shadow: 1px 1px 2px rgba(255, 255, 255, 0.2);
}
/* Custom CSS for select box heading */
div.stSelectbox > label {
    color: #000000 !important; /* Label text color */
    /* background-color: black !important; */
    font-size: 24px !important; /* Label font size */
    font-weight: bold !important; /* Make text bold */
}
/* Custom CSS for image caption (used via class="custom-caption") */
.custom-caption {
    color: #000000 !important; /* Caption text color */
    font-size: 24px !important; /* Caption font size */
    text-align: center; /* Center-align the caption */
}
/* White background for the main content container */
.stMainBlockContainer {
    background-color: white !important;
}
</style>
""", unsafe_allow_html=True)
| # Custom title styling functions | |
def colored_title(text, color):
    """Render ``text`` as an ``<h1>`` heading using the CSS colour ``color``.

    Both values are inserted into raw HTML unescaped, so callers should pass
    trusted strings only.
    """
    markup = "<h1 style='color: {};'>{}</h1>".format(color, text)
    st.markdown(markup, unsafe_allow_html=True)
def colored_subheader(text, color):
    """Render ``text`` as an ``<h3>`` sub-heading using the CSS colour ``color``.

    Both values are inserted into raw HTML unescaped, so callers should pass
    trusted strings only.
    """
    markup = "<h3 style='color: {};'>{}</h3>".format(color, text)
    st.markdown(markup, unsafe_allow_html=True)
def colored_text(text, color):
    """Render ``text`` inline as a ``<span>`` using the CSS colour ``color``.

    Both values are inserted into raw HTML unescaped, so callers should pass
    trusted strings only.
    """
    markup = "<span style='color: {};'>{}</span>".format(color, text)
    st.markdown(markup, unsafe_allow_html=True)
class ClassNet(nn.Module):
    """Small convolutional classifier that outputs 43 class logits.

    The first fully connected layer expects 512 = 32 * 4 * 4 flattened
    features. That is what the convolution/pooling stack produces for a
    3-channel 30x30 input (spatial size 30 -> 28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in the same order as before so that
        # parameter registration, state_dict keys and random initialisation
        # order stay unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=512, out_features=256)
        self.dropout1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc3 = nn.Linear(in_features=128, out_features=43)

    def forward(self, input):
        """Map an (N, 3, 30, 30) batch to raw (N, 43) class logits."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(input))
        features = self.maxpool1(F.relu(self.conv2(features)))
        features = self.maxpool2(F.relu(self.conv3(features)))
        # Fully connected head; dropout is only active in training mode.
        hidden = torch.flatten(features, 1)
        hidden = self.dropout1(F.relu(self.fc1(hidden)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)
def load_model(weights_path='traffic_light_model_weights.pth'):
    """Build a ``ClassNet`` and load trained weights into it on the CPU.

    Args:
        weights_path: Path of the saved ``state_dict``. Defaults to the file
            name the app has always used.

    Returns:
        The model switched to eval mode, or ``None`` if the weights could not
        be loaded. In that case the error is shown on the Streamlit page.
    """
    model = ClassNet()
    try:
        # NOTE(review): torch.load unpickles the file, so only load weight
        # files you trust. On torch versions that support it,
        # weights_only=True restricts loading to tensors.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        st.error(f"Error loading model: {e}")
        return None
    model.eval()  # disable dropout for inference
    return model
@st.cache_resource
def load_data(data_dir='traffic_lights', image_size=(30, 30), num_classes=43):
    """Load the test images, their labels, and one reference image per class.

    Streamlit re-executes the script on every widget interaction. The result
    is therefore cached with ``st.cache_resource`` so the images are not
    re-read on each rerun. The returned objects are shared between reruns
    and must not be mutated by callers.

    Args:
        data_dir: Folder containing ``Test.csv``, the image files it lists
            (paths relative to this folder), and a ``Meta`` sub-folder.
        image_size: ``(width, height)`` each test image is resized to. The
            512-feature head of ``ClassNet`` assumes 30x30.
        num_classes: How many ``Meta/<class_id>.png`` files to look for.

    Returns:
        ``(test_images, labels, meta_images)``:
        - ``test_images``: list of RGB ``uint8`` numpy arrays.
        - ``labels``: the ``ClassId`` values of the same rows, in the same
          order, so ``labels[i]`` belongs to ``test_images[i]``.
        - ``meta_images``: dict mapping class id to a PIL image, for every
          meta file that exists.
    """
    table = pd.read_csv(os.path.join(data_dir, 'Test.csv'))
    # Drop rows without a usable path *before* splitting the columns;
    # filtering only the images would shift every later label by one.
    table = table[table["Path"].map(lambda p: isinstance(p, str))]
    labels = table["ClassId"].values

    test_images = []
    for rel_path in table["Path"]:
        # ``with`` closes the file handle. convert("RGB") guarantees the
        # 3 channels that the model's first convolution expects.
        with Image.open(os.path.join(data_dir, rel_path)) as img:
            test_images.append(np.array(img.convert("RGB").resize(image_size)))

    # Reference ("clear") image per class, expected as Meta/0.png, Meta/1.png, ...
    meta_images = {}
    meta_folder = os.path.join(data_dir, 'Meta')
    for class_id in range(num_classes):
        meta_image_path = os.path.join(meta_folder, f"{class_id}.png")
        if os.path.exists(meta_image_path):
            with Image.open(meta_image_path) as img:
                # copy() loads the pixels into memory so the file can be closed.
                meta_images[class_id] = img.copy()
    return test_images, labels, meta_images
def main():
    """Render the page: choose a test image, classify it, and show the result."""
    colored_title("Traffic Symbol Prediction", "black")

    test_images, labels, meta_images = load_data()
    if len(test_images) == 0:
        # A selectbox over an empty range returns None, which would make the
        # indexing below fail; report the problem instead.
        st.warning("No test images were found to predict on.")
        return

    # Display test images for selection
    colored_subheader("Select an Image for Prediction:", "black")
    selected_index = st.selectbox("Select an image by index:", options=range(len(test_images)), index=0)

    # Display the selected test image together with its ground-truth class
    st.image(test_images[selected_index], width=150)
    st.markdown(
        f'<p class="custom-caption">Selected Test Image (Class: {labels[selected_index]})</p>',
        unsafe_allow_html=True
    )

    if not st.button("Predict"):
        return
    model = load_model()
    if model is None:
        return  # load_model has already shown the error

    # Scale pixels to [0, 1] and reorder the (H, W, C) array into the (1, C, H, W)
    # batch layout the model takes.
    # NOTE(review): assumes this matches the preprocessing used during training.
    image = test_images[selected_index] / 255.0
    image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        output = model(image)
    predicted_class = torch.argmax(output, dim=1).item()

    colored_subheader("Prediction Results:", "green")
    colored_text(f"Predicted Class: {predicted_class}", "green")

    # Show the reference ("clear") image of the predicted class, if one was found
    if predicted_class in meta_images:
        st.image(meta_images[predicted_class], width=150)
        st.markdown(
            f'<p class="custom-caption">Clear Image for Class: {predicted_class}</p>',
            unsafe_allow_html=True
        )
    else:
        st.warning(f"No clear image found for class {predicted_class} in the meta folder.")
if __name__ == "__main__":
    # Run the app only when this file is executed as the main script
    # (e.g. `streamlit run`), not when it is imported.
    main()