# cidm-a / app.py
# Last change: "Update app.py" by vlntnstarodub (commit 93bf3e4, 7.39 kB).
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from geopy.distance import geodesic
import json
import csv
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_optimizer import Lookahead
from torchvision import transforms
class GSNet(nn.Module):
    """Convolutional image classifier: three conv stages plus a dense head.

    Each conv stage is ``conv -> ELU -> 2x2 max-pool -> batch-norm -> dropout``.
    ``fc1`` takes ``250 * 3 * 3`` features, which the conv/pool stack yields
    for 3-channel inputs of e.g. 48x48 pixels
    (48 -> 42 -> 21 -> 17 -> 8 -> 6 -> 3).

    NOTE(review): ``forward`` returns the 450 outputs of ``fc3``, not
    ``nclasses`` logits; ``fc2``, ``fc2_bn``, ``fc3_bn`` and ``fc4``
    (450 -> nclasses) are created but never used. They are kept because their
    parameters are part of the state_dict that CI_model.pth is loaded into —
    confirm how the checkpoint was trained before wiring ``fc4`` in.

    Args:
        nclasses: output width of the (currently unused) ``fc4`` layer.
    """

    def __init__(self, nclasses):
        super().__init__()
        # Attribute names are the checkpoint's state_dict keys, and the
        # creation order fixes the RNG draws for default initialisation, so
        # neither is changed.
        self.conv1 = nn.Conv2d(3, 100, 7)
        self.conv1_bn = nn.BatchNorm2d(100)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(100, 150, 5)
        self.conv2_bn = nn.BatchNorm2d(150)
        self.conv3 = nn.Conv2d(150, 250, 3)
        self.conv3_bn = nn.BatchNorm2d(250)
        self.fc1 = nn.Linear(250 * 3 * 3, 350)
        self.fc1_bn = nn.BatchNorm1d(350)
        self.fc2 = nn.Linear(350, 400)
        self.fc2_bn = nn.BatchNorm1d(400)
        self.fc3 = nn.Linear(350, 450)
        self.fc3_bn = nn.BatchNorm1d(450)
        self.fc4 = nn.Linear(450, nclasses)
        self.dropout = nn.Dropout(p=0.6)

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image tensor to ``(batch, 450)`` outputs."""
        stages = (
            (self.conv1, self.conv1_bn),
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
        )
        for conv, norm in stages:
            pooled = self.pool(F.elu(conv(x)))
            x = self.dropout(norm(pooled))
        features = x.view(x.size(0), -1)
        hidden = self.dropout(self.fc1_bn(F.elu(self.fc1(features))))
        return self.fc3(hidden)
# Pre-trained image classifier used by the ML recommendation page.
nclasses = 9
model_path = "CI_model.pth"
# Prefer a GPU when one is available; otherwise run on the CPU. The ML
# recommendation page moves its image tensors to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
my_pretrained_model = GSNet(nclasses)
# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) so loading also works on CPU-only hosts.
my_pretrained_model.load_state_dict(torch.load(model_path, map_location=device))
my_pretrained_model.to(device)
# Inference mode: BatchNorm uses its running statistics and dropout is off.
my_pretrained_model.eval()
# Read data; show every column whenever a DataFrame is displayed.
pd.options.display.max_columns = None
df = pd.read_json("./Tourist_related.json")
# Round-trip the JSON through a CSV file: `read_dat` is the CSV view of `df`,
# and rows the CSV parser cannot read are skipped instead of aborting the app.
# NOTE(review): nested JSON fields (e.g. 'location', which is read as a dict
# from `df`) presumably end up as plain text in the CSV — confirm.
df.to_csv('./data.csv', index=False)
read_dat = pd.read_csv('./data.csv', on_bad_lines='skip')
# Data Quality Test
def data_quality_test():
    """Show each column's dtype and the per-column missing-value counts of ``read_dat``."""
    st.subheader("Data Quality Test")
    # One "<column>: <dtype>" line per column, in column order.
    for column, column_dtype in zip(read_dat.columns, read_dat.dtypes):
        st.write(f'{column}: {column_dtype}')
    st.write("Null Counts:")
    st.write(read_dat.isna().sum())
# Correlation Analysis
def correlation_analysis():
    """Render an annotated heatmap of pairwise correlations in ``read_dat``.

    Only numeric and boolean columns are correlated. They are selected
    explicitly because ``DataFrame.corr`` no longer drops text columns on its
    own (pandas >= 2.0 raises on them instead).
    """
    st.subheader("Correlation Analysis")
    numeric = read_dat.select_dtypes(include=['number', 'bool'])
    if numeric.empty:
        st.info("No numeric columns available for correlation analysis.")
        return
    # Draw on an explicit figure; calling st.pyplot() without one relies on
    # pyplot's global figure, which Streamlit has deprecated.
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(numeric.corr(), annot=True, ax=ax)
    st.pyplot(fig)
    # Release the figure so reruns do not accumulate open pyplot figures.
    plt.close(fig)
# Geo Plot Visualization
def geo_plot_visualization():
    """Plot every place in ``df`` on an OpenStreetMap map.

    Adds ``lat``/``lng`` columns to the module-level ``df`` from its
    ``location`` entries. Rows whose ``location`` is not a dict get no
    coordinates and are left off the map.
    """
    st.subheader("Geo Plot Visualization")

    def coordinate(location, key):
        # Rows without a location presumably come back from read_json as NaN
        # rather than a dict; calling .get on those would raise.
        return location.get(key, None) if isinstance(location, dict) else None

    df['lat'] = df['location'].apply(lambda loc: coordinate(loc, 'lat'))
    df['lng'] = df['location'].apply(lambda loc: coordinate(loc, 'lng'))
    points = df.dropna(subset=['lat', 'lng'])
    # No `color` column is mapped, so a continuous colour scale would have no
    # effect; it is therefore not passed.
    fig = px.scatter_mapbox(points,
                            lat="lat",
                            lon="lng",
                            zoom=8,
                            height=800,
                            width=800)
    fig.update_layout(mapbox_style="open-street-map")
    fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
    st.plotly_chart(fig)
# Geo Plot with ML Recommendation
def _predict_image_category(img, le):
    """Classify one uploaded image with ``my_pretrained_model``.

    Args:
        img: PIL image as returned by ``Image.open``.
        le: ``LabelEncoder`` fitted on ``read_dat['categoryName']``; its
            ``classes_`` serve as the CNN's class labels.

    Returns:
        The predicted category name, or ``None`` when the model's top output
        index has no corresponding entry in ``le.classes_``.
    """
    # GSNet's fc1 expects a 250x3x3 feature map, which its conv/pool stack
    # yields for 3-channel 48x48 inputs.
    # NOTE(review): training-time normalisation is not visible in this file;
    # confirm plain ToTensor() scaling matches how CI_model.pth was trained.
    preprocess = transforms.Compose([
        transforms.Resize((48, 48)),
        transforms.ToTensor(),
    ])
    # Follow the model to whatever device it was loaded on.
    device = next(my_pretrained_model.parameters()).device
    # convert("RGB") guarantees 3 channels for grayscale or CMYK JPEGs.
    img_tensor = preprocess(img.convert("RGB")).unsqueeze(0).to(device)
    my_pretrained_model.eval()
    with torch.no_grad():
        logits = my_pretrained_model(img_tensor)
    # NOTE(review): assumes the CNN's class indices follow the (sorted) order
    # of le.classes_ — confirm against the labels used in training.
    class_index = int(logits.argmax(dim=1).item())
    if class_index >= len(le.classes_):
        return None
    return le.classes_[class_index]


def _build_recommendation_map(recommended_locations, current_location, closest_location):
    """Map the recommended places, the user's position and a line to the closest place."""
    fig = px.scatter_mapbox(recommended_locations,
                            lat='latitude',
                            lon='longitude',
                            hover_data=['distance'],
                            mapbox_style="open-street-map",
                            zoom=10,
                            size_max=20)
    # The current position as its own trace (distance 0 by definition).
    fig.add_trace(px.scatter_mapbox(
        pd.DataFrame({'latitude': [current_location[0]],
                      'longitude': [current_location[1]],
                      'distance': [0]}),
        lat='latitude',
        lon='longitude',
        hover_data=['distance'],
        size_max=20,
    ).data[0])
    # Red segment from the current position to the closest recommended place.
    fig.add_trace(px.line_mapbox(
        lat=[current_location[0], closest_location['latitude']],
        lon=[current_location[1], closest_location['longitude']],
        color_discrete_sequence=['red'],
    ).data[0])
    return fig


def geo_plot_ml_recommendation():
    """Recommend the closest place whose category matches the uploaded photos.

    Each uploaded JPG is classified by the CNN. The predicted categories are
    passed through the RandomForest, and the majority category is chosen. All
    places of that category are then mapped, and the one geodesically closest
    to the entered position is highlighted.
    """
    st.subheader("Geo Plot with ML Recommendation")
    le = LabelEncoder()
    read_dat['encoded_category'] = le.fit_transform(read_dat['categoryName'])

    # Upload images and classify each one. Predictions are collected in a list
    # because DataFrame.append was removed in pandas 2.0.
    uploaded_files = st.file_uploader("Choose 4 images", type="jpg", accept_multiple_files=True)
    predicted_categories = []
    for file in uploaded_files or []:
        img = Image.open(file)
        st.image(img, caption="Uploaded Image", use_column_width=True)
        category = _predict_image_category(img, le)
        if category is None:
            st.warning("Could not map this image to a known category; skipping it.")
            continue
        predicted_categories.append(category)
    swiped_images = pd.DataFrame({'image_category': predicted_categories})

    # Enter Current Location
    current_latitude = st.number_input("Enter current latitude", value=54.898291)
    current_longitude = st.number_input("Enter current longitude", value=23.924751)

    if swiped_images.empty:
        # Nothing classified yet; RandomForest.predict would fail on 0 rows.
        st.info("Upload at least one image to get a recommendation.")
        return

    # Calculate ML Recommendation
    X = read_dat[['encoded_category']]
    y = read_dat['encoded_category']
    classifier = RandomForestClassifier()
    classifier.fit(X, y)
    swiped_images['encoded_category'] = le.transform(swiped_images['image_category'])
    per_image = list(classifier.predict(swiped_images[['encoded_category']]))
    # Majority vote across images; max() keeps the earliest image on ties.
    majority_category = max(per_image, key=per_image.count)

    current_location = (current_latitude, current_longitude)
    # NOTE(review): expects 'latitude'/'longitude' columns in read_dat, whereas
    # geo_plot_visualization reads coordinates from a nested 'location' field —
    # confirm the CSV actually has these columns.
    # .copy() makes this an independent frame before 'distance' is added.
    # Rows without coordinates are dropped because geodesic rejects NaN.
    recommended_locations = (
        read_dat.loc[read_dat['encoded_category'] == majority_category, ['latitude', 'longitude']]
        .dropna()
        .copy()
    )
    if recommended_locations.empty:
        st.warning("No locations with coordinates found for the recommended category.")
        return
    recommended_locations['distance'] = recommended_locations.apply(
        lambda row: geodesic(current_location, (row['latitude'], row['longitude'])).kilometers, axis=1
    )
    closest_location = recommended_locations.loc[recommended_locations['distance'].idxmin()]

    # Plot Geo Plot with ML Recommendation
    fig = _build_recommendation_map(recommended_locations, current_location, closest_location)
    st.plotly_chart(fig)
    st.write("Closest Recommended Location:")
    st.write(closest_location[['latitude', 'longitude']])
    st.write("Distance to Closest Location:", closest_location['distance'], "kilometers")
# Streamlit App
def main():
    """Draw the app shell and render the page chosen in the sidebar."""
    st.title("CIDM - TEAM A")
    # Sidebar: every page except "Home" maps to the function that renders it.
    st.sidebar.header("Options")
    pages = {
        "Data Quality Test": data_quality_test,
        "Correlation Analysis": correlation_analysis,
        "Geo Plot Visualization": geo_plot_visualization,
        "ML Recommendation Geo Plot": geo_plot_ml_recommendation,
    }
    option = st.sidebar.selectbox("Select an option", ["Home", *pages])
    if option == "Home":
        st.write("# Welcome to CIDM TEAM A - GROUP Project")
        st.write("Contributors: Mantvydas, Valentyna, Ali Al Ahmad, Ravinthiran")
        return
    render_page = pages.get(option)
    if render_page is not None:
        render_page()


if __name__ == '__main__':
    main()