Spaces:
Build error
Build error
File size: 3,647 Bytes
295bf7a deaeeaa 9f8218d deaeeaa feb8167 bb8b002 92c2707 9f8218d 295bf7a fa5fe13 e5f156c 295bf7a 9f8218d 295bf7a 94777bd 9f8218d 295bf7a 9f8218d 295bf7a 9f8218d 6c82a20 9f8218d 4a74c6c 1bf7664 591596d 4a74c6c feb8167 e5f156c a2b78bc feb8167 4a74c6c 1bf7664 4a74c6c bb8b002 d195805 92c2707 fa5fe13 92c2707 e349522 2f72399 92c2707 e5f156c d195805 e5f156c 4014435 e5f156c d195805 92c2707 c1265dc 4a74c6c a2b78bc 8314374 c1265dc a2b78bc 3d9d7bd 8314374 3d9d7bd 8314374 3d9d7bd 92c2707 430470f 571626b 1fbf136 4a74c6c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | import os
import hopsworks
import geopandas as gpd
from shapely.wkt import loads as wkt_loads
import folium
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import branca
# Name of the Hopsworks feature group that stores the speed predictions.
FEATURE_GROUP_NAME = "rs_predictions"
# Fixed colour-scale bounds shared by both maps and by the legend
# (relative speed: 0.0 = stopped, 1.0 = free flow, per the legend caption).
min_map, max_map = 0.0, 1.0
def connect_hopsworks():
    """Log in to the Hopsworks project and return its feature store.

    The API key is read from the ``HOPSWORKS_API_KEY`` environment variable
    and handed explicitly to ``hopsworks.login``.

    Returns:
        The feature store handle from ``project.get_feature_store()``.

    Raises:
        ValueError: If ``HOPSWORKS_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPSWORKS_API_KEY")
    if not api_key:
        raise ValueError("Hopsworks API key not found. Set the HOPSWORKS_API_KEY environment variable.")
    # Pass the validated key explicitly. Previously it was read and checked
    # but never given to login(), which then had to locate a key on its own.
    project = hopsworks.login(project="ScalableMLandDeepLcourse", api_key_value=api_key)
    return project.get_feature_store()
def fetch_linestring_data():
    """Read the prediction feature group into a GeoDataFrame.

    The ``coordinates`` column is stored as WKT text. It is parsed back into
    shapely geometries and used as the frame's active geometry column.

    Returns:
        geopandas.GeoDataFrame: Every row of the ``FEATURE_GROUP_NAME`` group.
    """
    feature_group = connect_hopsworks().get_feature_group(name=FEATURE_GROUP_NAME)
    frame = feature_group.read()
    # WKT strings -> shapely geometry objects, so geopandas can use the column.
    frame["coordinates"] = frame["coordinates"].apply(wkt_loads)
    return gpd.GeoDataFrame(frame, geometry="coordinates")
def create_map(gdf, feature='predicted_rs', *, location=(59.34318, 18.05141),
               zoom_start=14, vmin=None, vmax=None):
    """Render road geometries on a Folium map, coloured by ``feature``.

    Each geometry in the ``coordinates`` column becomes its own GeoJSON layer.
    Its colour comes from the RdYlGn colormap applied to the normalised
    ``feature`` value.

    Args:
        gdf: GeoDataFrame with a ``coordinates`` geometry column and a numeric
            ``feature`` column.
        feature: Column whose values drive the line colour.
        location: Map centre as ``(lat, lon)``; defaults to central Stockholm.
        zoom_start: Initial Folium zoom level.
        vmin: Lower colour-scale bound; defaults to the module-level ``min_map``.
        vmax: Upper colour-scale bound; defaults to the module-level ``max_map``.

    Returns:
        str: The map's HTML (``Map._repr_html_()``), for a Gradio HTML component.
    """
    m = folium.Map(location=list(location), zoom_start=zoom_start)
    # Fixed bounds (rather than the data's own min/max) keep colours comparable
    # between maps and consistent with the shared legend.
    norm = mcolors.Normalize(
        vmin=min_map if vmin is None else vmin,
        vmax=max_map if vmax is None else vmax,
    )
    cmap = plt.cm.RdYlGn
    # Walk only the two needed columns; iterrows() builds a Series per row.
    for geometry, value in zip(gdf['coordinates'], gdf[feature]):
        color = mcolors.to_hex(cmap(norm(value)))
        folium.GeoJson(
            geometry,
            # Bind colour as a default argument. A bare closure would
            # late-bind and paint every line with the last colour.
            style_function=lambda _feature, color=color: {'color': color},
        ).add_to(m)
    return m._repr_html_()
def create_histogram(gdf, feature='predicted_rs', bins=20):
    """Compute a histogram of ``gdf[feature]``, print it, and return it.

    Missing values are dropped first. ``np.histogram`` raises ``ValueError``
    when its auto-detected range contains NaN, so a single missing prediction
    would otherwise crash the call.

    Args:
        gdf: DataFrame/GeoDataFrame containing the ``feature`` column.
        feature: Numeric column to bin.
        bins: Number of equal-width bins (or explicit edges) for ``np.histogram``.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(counts, bin_edges)`` as returned by
        ``np.histogram``.
    """
    values = gdf[feature].dropna().to_numpy()
    hist, bin_edges = np.histogram(values, bins=bins)
    print("Histogram counts:", hist)
    print("Bin edges:", bin_edges)
    return hist, bin_edges
def generate_legend_html(cmap, norm, n_colors=11,
                         caption='Relative Traffic Speed (1.0 = Free Flow | 0.0 = Stopped)'):
    """Build an HTML colour-bar legend that matches the map colouring.

    Args:
        cmap: Matplotlib colormap used to colour the map lines.
        norm: ``matplotlib.colors.Normalize`` whose ``vmin``/``vmax`` define
            the legend's value range.
        n_colors: Number of evenly spaced colour stops sampled from ``cmap``
            (clamped to at least 2, the minimum a linear colormap needs).
        caption: Text displayed with the colour bar.

    Returns:
        str: HTML for a ``branca`` ``LinearColormap``.
    """
    # Sample the whole colormap, not just its endpoints. With only the two
    # ends, branca interpolates straight from red to green and the legend
    # loses the yellow midpoint that RdYlGn (and therefore the map) shows.
    stops = np.linspace(norm.vmin, norm.vmax, max(int(n_colors), 2))
    colormap = branca.colormap.LinearColormap(
        colors=[mcolors.to_hex(cmap(norm(v))) for v in stops],
        vmin=norm.vmin,
        vmax=norm.vmax,
        caption=caption,
    )
    return colormap._repr_html_()
def display_maps():
    """Fetch fresh data and build every HTML fragment the Gradio UI shows.

    Returns:
        tuple[str, str, str]: The map coloured by observed ``relativespeed``,
        the map coloured by ``predicted_rs``, and the shared colour legend.
    """
    gdf = fetch_linestring_data()
    # Same data, two colourings: current observation first, then prediction.
    current_html, predicted_html = (
        create_map(gdf, feature=column)
        for column in ('relativespeed', 'predicted_rs')
    )
    legend = generate_legend_html(
        plt.cm.RdYlGn,
        mcolors.Normalize(vmin=min_map, vmax=max_map),
    )
    return current_html, predicted_html, legend
if __name__ == "__main__":
    # Layout: title and trigger button, two maps side by side, legend below.
    with gr.Blocks() as iface:
        with gr.Row():
            gr.Markdown("# Traffic in Stockholm near Odenplan")
            generate_button = gr.Button("Generate Maps", variant="primary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Current traffic speed")
                current_map = gr.HTML()
            with gr.Column():
                gr.Markdown("### Traffic speed in 1 hour")
                forecast_map = gr.HTML()

        with gr.Row():
            legend_panel = gr.HTML()

        # No inputs: each click re-fetches the data and redraws maps + legend.
        generate_button.click(
            fn=display_maps,
            inputs=[],
            outputs=[current_map, forecast_map, legend_panel],
        )

    iface.launch()