import os

import branca
import folium
import geopandas as gpd
import gradio as gr
import hopsworks
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from shapely.wkt import loads as wkt_loads

# Hardcoded feature group name
FEATURE_GROUP_NAME = "rs_predictions"
# Hopsworks project that hosts the feature group.
PROJECT_NAME = "ScalableMLandDeepLcourse"

# Fixed colour-scale bounds for relative speed. Both maps and the legend use
# this one scale rather than stretching to each map's own data range.
min_map = 0.0
max_map = 1.0

# Red (low) -> yellow -> green (high); shared by the maps and the legend.
COLORMAP = plt.cm.RdYlGn
# Colour for segments whose value is missing. Passing NaN through the
# colormap yields its transparent "bad" colour, which to_hex renders as black.
MISSING_COLOR = "#808080"
# Initial map view, centred on Stockholm.
MAP_CENTER = [59.34318, 18.05141]
MAP_ZOOM = 14


def _speed_norm():
    """Return the fixed [min_map, max_map] normaliser shared by maps and legend."""
    return mcolors.Normalize(vmin=min_map, vmax=max_map)


def connect_hopsworks():
    """Log in to Hopsworks and return the project's feature store.

    Raises:
        ValueError: if the HOPSWORKS_API_KEY environment variable is unset or empty.
    """
    api_key = os.getenv("HOPSWORKS_API_KEY")
    if not api_key:
        raise ValueError("Hopsworks API key not found. Set the HOPSWORKS_API_KEY environment variable.")
    # Hand the validated key to login() explicitly, so the check above is the
    # credential actually used rather than relying on login()'s own lookup.
    project = hopsworks.login(project=PROJECT_NAME, api_key_value=api_key)
    return project.get_feature_store()


def fetch_linestring_data():
    """Read FEATURE_GROUP_NAME from Hopsworks into a GeoDataFrame.

    The "coordinates" column holds WKT strings. They are parsed into shapely
    geometries and used as the active geometry column. Non-string entries
    (e.g. missing values) become None instead of aborting the whole load.
    """
    fs = connect_hopsworks()
    fg = fs.get_feature_group(name=FEATURE_GROUP_NAME)
    data = fg.read()
    data["coordinates"] = data["coordinates"].apply(
        lambda wkt: wkt_loads(wkt) if isinstance(wkt, str) else None
    )
    return gpd.GeoDataFrame(data, geometry="coordinates")


def _value_to_hex(value, cmap, norm):
    """Map a scalar to a hex colour; missing values (None/NaN/NA) get MISSING_COLOR."""
    if pd.isna(value):
        return MISSING_COLOR
    return mcolors.to_hex(cmap(norm(value)))


def create_map(gdf, feature="predicted_rs"):
    """Render every road segment in *gdf*, coloured by the *feature* column.

    Colours come from COLORMAP on the fixed [min_map, max_map] scale.

    Returns:
        The Folium map as an HTML string, suitable for a gr.HTML component.
    """
    m = folium.Map(location=MAP_CENTER, zoom_start=MAP_ZOOM)
    norm = _speed_norm()

    for geometry, value in zip(gdf["coordinates"], gdf[feature]):
        # Rows whose WKT was missing or unparseable have nothing to draw.
        if geometry is None or geometry.is_empty:
            continue
        color = _value_to_hex(value, COLORMAP, norm)
        folium.GeoJson(
            geometry,
            # Bind color as a default arg so each layer keeps its own colour
            # (a plain closure would see only the loop's last value).
            style_function=lambda _feature, color=color: {"color": color},
        ).add_to(m)

    return m._repr_html_()


def create_histogram(gdf, feature="predicted_rs", bins=20):
    """Compute, print and return a histogram of the *feature* column.

    Missing values are dropped first: np.histogram cannot auto-detect a
    range when NaNs are present.

    Returns:
        (hist, bin_edges) as returned by np.histogram.
    """
    values = gdf[feature].dropna()
    hist, bin_edges = np.histogram(values, bins=bins)
    print("Histogram counts:", hist)
    print("Bin edges:", bin_edges)
    return hist, bin_edges


def generate_legend_html(cmap, norm, n_stops=11):
    """Return HTML for a colour-bar legend that matches *cmap* under *norm*.

    *n_stops* evenly spaced values between norm.vmin and norm.vmax are sampled
    so that multi-hue colormaps are reproduced. RdYlGn, for example, passes
    through yellow; interpolating between only the two endpoint colours would
    show a red->green ramp that does not match the maps.
    """
    stops = np.linspace(norm.vmin, norm.vmax, max(int(n_stops), 2))
    colormap = branca.colormap.LinearColormap(
        colors=[mcolors.to_hex(cmap(norm(v))) for v in stops],
        index=[float(v) for v in stops],
        vmin=norm.vmin,
        vmax=norm.vmax,
        caption='Relative Traffic Speed (1.0 = Free Flow | 0.0 = Stopped)',
    )
    return colormap._repr_html_()


def display_maps():
    """Gradio callback: return (current-speed map, predicted map, legend) HTML."""
    gdf = fetch_linestring_data()
    map_current = create_map(gdf, feature='relativespeed')
    map_predicted = create_map(gdf, feature='predicted_rs')
    legend_html = generate_legend_html(COLORMAP, _speed_norm())
    return map_current, map_predicted, legend_html


if __name__ == "__main__":
    with gr.Blocks() as iface:
        with gr.Row():
            gr.Markdown("# Traffic in Stockholm near Odenplan")
            generate_button = gr.Button("Generate Maps", variant="primary")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Current traffic speed")
                map_output1 = gr.HTML()
            with gr.Column():
                gr.Markdown("### Traffic speed in 1 hour")
                map_output2 = gr.HTML()
        with gr.Row():
            legend_output = gr.HTML()
        generate_button.click(
            display_maps,
            inputs=[],
            outputs=[map_output1, map_output2, legend_output],
        )
    iface.launch()