Spaces:
Build error
Build error
| import os | |
| import hopsworks | |
| import geopandas as gpd | |
| from shapely.wkt import loads as wkt_loads | |
| import folium | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| import numpy as np | |
| import branca | |
# Name of the Hopsworks feature group read by fetch_linestring_data().
FEATURE_GROUP_NAME = "rs_predictions"
# Fixed colour-scale bounds shared by both maps and the legend, so the two maps
# are coloured on the same scale (legend caption: 1.0 = free flow, 0.0 = stopped).
min_map, max_map = 0.0, 1.0
# Connect to the Hopsworks project and return its feature store.
def connect_hopsworks():
    """Log in to Hopsworks and return the project's feature store.

    The API key is read from the ``HOPSWORKS_API_KEY`` environment variable,
    validated, and passed explicitly to ``hopsworks.login``.

    Returns:
        The handle returned by ``project.get_feature_store()``.

    Raises:
        ValueError: If ``HOPSWORKS_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPSWORKS_API_KEY")
    if not api_key:
        raise ValueError("Hopsworks API key not found. Set the HOPSWORKS_API_KEY environment variable.")
    # Pass the key explicitly so login uses exactly the value validated above
    # instead of relying on hopsworks' own credential discovery.
    project = hopsworks.login(project="ScalableMLandDeepLcourse", api_key_value=api_key)
    return project.get_feature_store()
# Load the feature group into a GeoDataFrame of LineString geometries.
def fetch_linestring_data():
    """Read ``FEATURE_GROUP_NAME`` from the feature store as a GeoDataFrame.

    The ``coordinates`` column holds WKT strings; they are parsed back into
    shapely geometries and that column becomes the frame's geometry column.
    """
    feature_group = connect_hopsworks().get_feature_group(name=FEATURE_GROUP_NAME)
    frame = feature_group.read()
    frame["coordinates"] = frame["coordinates"].apply(wkt_loads)
    return gpd.GeoDataFrame(frame, geometry="coordinates")
# Render a Folium map of the segments in gdf, coloured by a numeric column.
def create_map(gdf, feature='predicted_rs'):
    """Build a Folium map of ``gdf`` and return it as an HTML string.

    Each row's ``coordinates`` geometry is added as its own GeoJSON layer.
    Its colour comes from the RdYlGn colormap applied to ``row[feature]``,
    normalised to the fixed ``[min_map, max_map]`` range rather than the data's
    own min/max, so every map built here shares one colour scale.
    """
    folium_map = folium.Map(location=[59.34318, 18.05141], zoom_start=14)  # Stockholm
    scale = mcolors.Normalize(vmin=min_map, vmax=max_map)
    palette = plt.cm.RdYlGn

    for _, segment in gdf.iterrows():
        hex_colour = mcolors.to_hex(palette(scale(segment[feature])))
        # Bind the colour as a default argument: a plain closure would read the
        # loop variable's final value if Folium evaluates the style function
        # after the loop (e.g. at render time).
        layer = folium.GeoJson(
            segment['coordinates'],
            style_function=lambda _geo_feature, c=hex_colour: {'color': c},
        )
        layer.add_to(folium_map)

    return folium_map._repr_html_()
# Summarise a numeric column as a histogram and print it (debugging aid).
def create_histogram(gdf, feature='predicted_rs', bins=20):
    """Compute, print and return a histogram of ``gdf[feature]``.

    Non-finite values (NaN/inf) are dropped first. With ``range`` left
    unspecified, ``np.histogram`` raises if the data's min/max are not
    finite, so a single missing prediction would otherwise crash this.

    Args:
        gdf: DataFrame (or any mapping) indexable by column name.
        feature: Column to summarise; defaults to ``'predicted_rs'``.
        bins: Bin specification forwarded to ``np.histogram``; defaults to 20.

    Returns:
        ``(hist, bin_edges)`` exactly as returned by ``np.histogram``.
    """
    values = np.asarray(gdf[feature], dtype=float)
    values = values[np.isfinite(values)]
    hist, bin_edges = np.histogram(values, bins=bins)
    print("Histogram counts:", hist)
    print("Bin edges:", bin_edges)
    return hist, bin_edges
# Build an HTML colour-bar legend matching the colouring used by the maps.
def generate_legend_html(cmap, norm, n_colors=11):
    """Return HTML for a branca colour-bar legend of ``cmap`` over ``norm``.

    Args:
        cmap: Matplotlib colormap used to colour the map lines.
        norm: ``matplotlib.colors.Normalize`` supplying ``vmin``/``vmax``.
        n_colors: Number of evenly spaced colours sampled from ``cmap``
            (clamped to at least 2). ``LinearColormap`` interpolates linearly
            between the colours it is given. With only the two endpoints, a
            diverging map such as RdYlGn would lose its yellow midpoint and
            the legend would not match the maps.

    Returns:
        The legend rendered as an HTML string.
    """
    n_colors = max(2, int(n_colors))
    # Evenly spaced stops match LinearColormap's default (evenly spaced) index.
    stops = np.linspace(norm.vmin, norm.vmax, n_colors)
    colormap = branca.colormap.LinearColormap(
        colors=[mcolors.to_hex(cmap(norm(v))) for v in stops],
        vmin=norm.vmin,
        vmax=norm.vmax,
        caption='Relative Traffic Speed (1.0 = Free Flow | 0.0 = Stopped)',
    )
    return colormap._repr_html_()
# Gradio callback: build both maps and the shared legend from fresh data.
def display_maps():
    """Fetch the latest data and render the two maps plus their legend.

    Returns:
        ``(map1_html, map2_html, legend_html)``: the map coloured by
        ``relativespeed``, the map coloured by ``predicted_rs``, and a legend
        for the RdYlGn scale over ``[min_map, max_map]``.
    """
    segments = fetch_linestring_data()
    current_html, predicted_html = (
        create_map(segments, feature=column)
        for column in ('relativespeed', 'predicted_rs')
    )
    legend = generate_legend_html(
        plt.cm.RdYlGn,
        mcolors.Normalize(vmin=min_map, vmax=max_map),
    )
    return current_html, predicted_html, legend
if __name__ == "__main__":
    # Layout: title + button, the two maps side by side, then the legend.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown("# Traffic in Stockholm near Odenplan")
            refresh = gr.Button("Generate Maps", variant="primary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Current traffic speed")
                current_map = gr.HTML()
            with gr.Column():
                gr.Markdown("### Traffic speed in 1 hour")
                forecast_map = gr.HTML()

        with gr.Row():
            legend = gr.HTML()

        # Maps are built only when the button is pressed; the listener is
        # registered inside the Blocks context.
        refresh.click(
            display_maps,
            inputs=[],
            outputs=[current_map, forecast_map, legend],
        )

    demo.launch()