# app.py — Streamlit traffic-sign classifier
# (Hugging Face Space by YAMITEK, commit 91241f5)
import streamlit as st
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image
from sklearn.preprocessing import StandardScaler, LabelEncoder
st.set_page_config(layout="centered")

# Inject custom CSS: full-page background image with a translucent white
# overlay so the app content stays readable on top of it.
# Fixes: removed a duplicated comment line, and replaced two Python-style
# "#" pseudo-comments inside the stylesheet — "#" is NOT a comment marker
# in CSS, so those lines were invalid declarations.
st.markdown("""
<style>
.stApp {
    background-image: url("https://as1.ftcdn.net/jpg/01/82/21/76/1000_F_182217694_DZi3Ytqsb0RpWQb9dwC7NLFwkwqgnh0r.jpg");
    background-size: cover;
    background-position: center;
    background-repeat: no-repeat;
    height: auto;   /* Allows the page to expand for scrolling */
    overflow: auto; /* Enables scrolling if the page content overflows */
}
/* Adjust opacity of overlay to make content more visible */
.stApp::before {
    content: "";
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    background-color: rgba(255, 255, 255, 0.8); /* Slightly higher opacity */
    z-index: -1;
}
/* Ensure content appears above the overlay */
.stApp > * {
    position: relative;
    z-index: 2;
}
/* Ensure the dataframe is visible */
.dataframe {
    background-color: rgba(255, 255, 255, 0.9) !important;
    z-index: 3;
}
/* Style text elements for better visibility */
h1, h3, span, div {
    text-shadow: 1px 1px 2px rgba(255, 255, 255, 0.2);
}
/* Custom CSS for select box heading */
div.stSelectbox > label {
    color: #000000 !important;
    font-size: 24px !important;
    font-weight: bold !important;
}
/* Custom CSS for image caption */
.custom-caption {
    color: #000000 !important;
    font-size: 24px !important;
    text-align: center;
}
.stMainBlockContainer {
    background-color: white !important;
}
</style>
""", unsafe_allow_html=True)
# Custom title styling functions
def colored_title(text, color):
    """Render *text* as an <h1> heading in the given CSS color."""
    heading_html = f"<h1 style='color: {color};'>{text}</h1>"
    st.markdown(heading_html, unsafe_allow_html=True)
def colored_subheader(text, color):
    """Render *text* as an <h3> subheading in the given CSS color."""
    subheading_html = f"<h3 style='color: {color};'>{text}</h3>"
    st.markdown(subheading_html, unsafe_allow_html=True)
def colored_text(text, color):
    """Render *text* as an inline <span> in the given CSS color."""
    span_html = f"<span style='color: {color};'>{text}</span>"
    st.markdown(span_html, unsafe_allow_html=True)
class ClassNet(nn.Module):
    """CNN classifier for 43 traffic-sign classes.

    Expects input of shape (N, 3, 30, 30); the conv/pool stack reduces it
    to 32 x 4 x 4 = 512 features, which feeds the fully connected head.
    Returns raw logits of shape (N, 43) — no softmax applied.

    NOTE: attribute names (conv1, fc1, ...) must stay fixed because they
    are the keys of the saved state_dict loaded in load_model().
    """

    def __init__(self):
        super(ClassNet, self).__init__()
        # Feature extractor:
        #   3x30x30 -conv3-> 6x28x28 -conv5-> 16x24x24 -pool-> 16x12x12
        #   -conv5-> 32x8x8 -pool-> 32x4x4  (= 512 flat features)
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(16, 32, 5)
        self.maxpool2 = nn.MaxPool2d(2)
        # Classifier head with dropout regularization
        self.fc1 = nn.Linear(512, 256)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 43)

    def forward(self, x):
        """Compute class logits for a batch of (N, 3, 30, 30) images.

        Parameter renamed from ``input`` to ``x`` to avoid shadowing the
        builtin; callers pass the tensor positionally, so this is safe.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv3(x))
        x = self.maxpool2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return self.fc3(x)
@st.cache_resource
def load_model():
    """Load the trained ClassNet weights from disk (cached by Streamlit).

    Returns the model in eval mode, or None (after showing a UI error)
    when the weights file cannot be read or does not match the network.
    """
    net = ClassNet()
    try:
        weights = torch.load('traffic_light_model_weights.pth',
                             map_location=torch.device('cpu'))
        net.load_state_dict(weights)
        net.eval()  # disable dropout for deterministic inference
        return net
    except Exception as exc:
        st.error(f"Error loading model: {str(exc)}")
        return None
@st.cache_data
def load_data():
    """Load the test set and per-class reference images (cached).

    Returns:
        test_images: list of 30x30 image arrays (one per valid CSV row).
        labels: numpy array of class ids aligned 1:1 with test_images.
            Bug fix: the original returned ALL ClassId values even when
            rows with a non-string Path were skipped, which would shift
            every subsequent image out of alignment with its label.
        meta_images: dict mapping class id -> PIL "clear" reference image.
    """
    y_test = pd.read_csv('traffic_lights/Test.csv')
    test_images = []
    kept_labels = []
    for path, class_id in zip(y_test["Path"].values, y_test["ClassId"].values):
        if not isinstance(path, str):
            continue  # skip malformed rows (e.g. NaN paths) without crashing
        image = Image.open('traffic_lights/' + path)
        image = image.resize([30, 30])
        test_images.append(np.array(image))
        kept_labels.append(class_id)
    labels = np.array(kept_labels)

    # Load per-class reference ("meta") images, named as 0.png, 1.png, ...
    meta_images = {}
    meta_folder = 'traffic_lights/Meta'
    for class_id in range(43):
        meta_image_path = os.path.join(meta_folder, f"{class_id}.png")
        if os.path.exists(meta_image_path):
            meta_images[class_id] = Image.open(meta_image_path)
    return test_images, labels, meta_images
def main():
    """Streamlit entry point: choose a test image and classify it."""
    colored_title("Traffic Symbol Prediction", "black")

    # Cached dataset: test images, ground-truth labels, per-class references.
    test_images, labels, meta_images = load_data()

    colored_subheader("Select an Image for Prediction:", "black")
    idx = st.selectbox("Select an image by index:", options=range(len(test_images)), index=0)

    # Show the chosen image together with its ground-truth class id.
    st.image(test_images[idx], width=150)
    st.markdown(
        f'<p class="custom-caption">Selected Test Image (Class: {labels[idx]})</p>',
        unsafe_allow_html=True
    )

    # Guard clauses: only run inference after a click and a successful load.
    if not st.button("Predict"):
        return
    model = load_model()
    if model is None:
        return

    # Normalize to [0, 1] and convert HWC -> 1xCxHxW float tensor.
    tensor = torch.tensor(test_images[idx] / 255.0, dtype=torch.float32)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        logits = model(tensor)
    predicted_class = torch.argmax(logits, dim=1).item()

    colored_subheader("Prediction Results:", "green")
    colored_text(f"Predicted Class: {predicted_class}", "green")

    # Show the clean reference image for the predicted class, if available.
    if predicted_class in meta_images:
        st.image(meta_images[predicted_class], width=150)
        st.markdown(
            f'<p class="custom-caption">Clear Image for Class: {predicted_class}</p>',
            unsafe_allow_html=True
        )
    else:
        st.warning(f"No clear image found for class {predicted_class} in the meta folder.")
# Run the app only when executed directly (not when imported as a module).
if __name__ == "__main__":
    main()