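# NOTE(review): the three lines below look like commit metadata (author, message,
# hash) pasted into the source; they are kept as comments so the module parses.
# IST199655
# Update app.py
# e349522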
import os
import hopsworks
import geopandas as gpd
from shapely.wkt import loads as wkt_loads
import folium
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import branca
# Hopsworks feature group that holds both the observed `relativespeed` and the
# model's `predicted_rs` columns for each road segment.
FEATURE_GROUP_NAME = "rs_predictions"

# Fixed colour-scale bounds used by every map and by the legend, so maps drawn
# from different columns are directly comparable (rather than each being
# scaled to its own data min/max).
min_map, max_map = 0.0, 1.0
def connect_hopsworks():
    """Log in to the Hopsworks project and return its feature store.

    The API key is taken from the ``HOPSWORKS_API_KEY`` environment variable
    and passed explicitly to ``hopsworks.login``.

    Returns:
        The feature store handle of the ``ScalableMLandDeepLcourse`` project.

    Raises:
        ValueError: If ``HOPSWORKS_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPSWORKS_API_KEY")
    if not api_key:
        raise ValueError("Hopsworks API key not found. Set the HOPSWORKS_API_KEY environment variable.")
    # Use the key we just validated. Previously it was checked but never handed
    # to login, so the check did not guarantee which credentials were used.
    project = hopsworks.login(project="ScalableMLandDeepLcourse", api_key_value=api_key)
    return project.get_feature_store()
def fetch_linestring_data():
    """Load the prediction feature group as a GeoDataFrame.

    The ``coordinates`` column holds WKT strings in the feature group; they are
    parsed into shapely geometries and that column is made the active geometry.

    Returns:
        A GeoDataFrame with ``coordinates`` as its geometry column.
    """
    frame = connect_hopsworks().get_feature_group(name=FEATURE_GROUP_NAME).read()
    # WKT text -> shapely geometry, in place on the freshly read frame.
    frame["coordinates"] = frame["coordinates"].apply(wkt_loads)
    return gpd.GeoDataFrame(frame, geometry="coordinates")
def create_map(gdf, feature='predicted_rs'):
    """Render a Folium map of road segments coloured by ``feature``.

    Each ``coordinates`` geometry is added as its own GeoJson layer. Its colour
    comes from the RdYlGn colormap on the fixed scale [``min_map``, ``max_map``]
    rather than the data's own range, so maps of different columns share one
    colour scale.

    Args:
        gdf: GeoDataFrame with a ``coordinates`` geometry column and a numeric
            ``feature`` column.
        feature: Name of the column whose value selects each line's colour.

    Returns:
        The map rendered as an embeddable HTML string.
    """
    # Initial view centred on Stockholm (near Odenplan).
    folium_map = folium.Map(location=[59.34318, 18.05141], zoom_start=14)
    scale = mcolors.Normalize(vmin=min_map, vmax=max_map)
    palette = plt.cm.RdYlGn
    for geometry, value in zip(gdf['coordinates'], gdf[feature]):
        hex_color = mcolors.to_hex(palette(scale(value)))
        # Bind hex_color as a default argument so each layer keeps its own
        # colour instead of the loop's last one (late-binding closure).
        folium.GeoJson(
            geometry,
            style_function=lambda _geojson, c=hex_color: {'color': c},
        ).add_to(folium_map)
    return folium_map._repr_html_()
def create_histogram(gdf, feature='predicted_rs', bins=20):
    """Compute and print a histogram of one numeric column.

    Args:
        gdf: Table-like object indexable by column name, e.g. a (Geo)DataFrame.
        feature: Column to histogram; defaults to the model predictions.
        bins: Number of equal-width bins, forwarded to ``np.histogram``.

    Returns:
        ``(hist, bin_edges)`` as returned by ``np.histogram``.
    """
    values = np.asarray(gdf[feature], dtype=float)
    # NaN makes np.histogram's auto-detected range non-finite and it raises,
    # so missing values are excluded before binning.
    values = values[~np.isnan(values)]
    hist, bin_edges = np.histogram(values, bins=bins)
    print("Histogram counts:", hist)
    print("Bin edges:", bin_edges)
    return hist, bin_edges
def generate_legend_html(cmap, norm, n_colors=11):
    """Render an HTML colour-bar legend that matches the map colouring.

    Args:
        cmap: Matplotlib colormap used to colour the map lines.
        norm: ``matplotlib.colors.Normalize`` whose ``vmin``/``vmax`` bound the
            scale shown in the legend.
        n_colors: Number of evenly spaced samples taken from ``cmap`` (at least
            2). More samples reproduce multi-hue colormaps more faithfully.

    Returns:
        HTML string produced by branca's ``LinearColormap``.
    """
    # Sampling only the two endpoints made branca interpolate linearly between
    # them, losing interior hues such as RdYlGn's yellow midpoint, so the legend
    # did not match the maps.
    stops = np.linspace(norm.vmin, norm.vmax, max(int(n_colors), 2))
    colormap = branca.colormap.LinearColormap(
        colors=[mcolors.to_hex(cmap(norm(v))) for v in stops],
        vmin=norm.vmin,
        vmax=norm.vmax,
        caption='Relative Traffic Speed (1.0 = Free Flow | 0.0 = Stopped)'
    )
    return colormap._repr_html_()
def display_maps():
    """Gradio callback: fetch fresh data and build both maps plus the legend.

    Returns:
        Three HTML strings, in order:
        - the map coloured by ``relativespeed`` (current traffic speed);
        - the map coloured by ``predicted_rs`` (speed in 1 hour);
        - the shared colour legend.
    """
    gdf = fetch_linestring_data()
    # Built in this order: current speed first, then the prediction.
    current_html, forecast_html = [
        create_map(gdf, feature=column)
        for column in ('relativespeed', 'predicted_rs')
    ]
    legend = generate_legend_html(
        plt.cm.RdYlGn,
        mcolors.Normalize(vmin=min_map, vmax=max_map),
    )
    return current_html, forecast_html, legend
if __name__ == "__main__":
    # Build the Gradio Blocks UI; the layout follows the nested `with` contexts.
    with gr.Blocks() as iface:
        # Header row: page title plus the button that loads data and draws maps.
        with gr.Row():
            gr.Markdown("# Traffic in Stockholm near Odenplan")
            generate_button = gr.Button("Generate Maps", variant="primary")
        # Two columns side by side: current speed map and 1-hour prediction map.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Current traffic speed")
                map_output1 = gr.HTML()
            with gr.Column():
                gr.Markdown("### Traffic speed in 1 hour")
                map_output2 = gr.HTML()
        # Colour legend shared by both maps, in its own row beneath them.
        with gr.Row():
            legend_output = gr.HTML()
        # display_maps() takes no inputs and returns (map1, map2, legend_html),
        # which are assigned to these outputs positionally.
        generate_button.click(display_maps, inputs=[],
                              outputs=[map_output1, map_output2, legend_output])
    # Start serving the app.
    iface.launch()