File size: 3,647 Bytes
295bf7a
deaeeaa
9f8218d
deaeeaa
 
 
 
feb8167
bb8b002
92c2707
9f8218d
295bf7a
 
fa5fe13
e5f156c
295bf7a
9f8218d
 
295bf7a
 
 
 
 
 
94777bd
9f8218d
 
 
 
295bf7a
9f8218d
295bf7a
9f8218d
6c82a20
 
9f8218d
 
4a74c6c
1bf7664
591596d
4a74c6c
feb8167
e5f156c
 
a2b78bc
feb8167
4a74c6c
 
1bf7664
4a74c6c
 
 
 
 
 
 
bb8b002
 
 
 
 
 
 
d195805
92c2707
 
 
 
fa5fe13
92c2707
e349522
2f72399
 
92c2707
e5f156c
d195805
e5f156c
4014435
e5f156c
 
d195805
92c2707
c1265dc
4a74c6c
a2b78bc
 
8314374
c1265dc
a2b78bc
3d9d7bd
8314374
3d9d7bd
 
8314374
3d9d7bd
92c2707
 
430470f
571626b
1fbf136
4a74c6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import hopsworks
import geopandas as gpd
from shapely.wkt import loads as wkt_loads
import folium
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import branca

# Hopsworks feature group read by fetch_linestring_data().
FEATURE_GROUP_NAME = "rs_predictions"

# Fixed colour-scale bounds: the maps and the legend both normalise values
# into [min_map, max_map] before looking up a colour.
min_map, max_map = 0.0, 1.0

# Connect to Hopsworks Project and Feature Store
def connect_hopsworks(project_name="ScalableMLandDeepLcourse"):
    """Log in to Hopsworks and return the project's feature store.

    The API key is read from the ``HOPSWORKS_API_KEY`` environment variable
    and passed explicitly to ``hopsworks.login`` so that the key validated
    here is the one actually used to authenticate.

    Args:
        project_name: Hopsworks project to log in to. Defaults to the
            project this dashboard was written against.

    Returns:
        The handle returned by ``project.get_feature_store()``.

    Raises:
        ValueError: If ``HOPSWORKS_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPSWORKS_API_KEY")
    if not api_key:
        raise ValueError("Hopsworks API key not found. Set the HOPSWORKS_API_KEY environment variable.")

    # Hand the validated key to login() explicitly instead of checking it and
    # then leaving login() to obtain credentials some other way.
    project = hopsworks.login(project=project_name, api_key_value=api_key)
    return project.get_feature_store()

# Fetch data from the feature group
def fetch_linestring_data():
    """Load the prediction feature group as a GeoDataFrame.

    The ``coordinates`` column arrives as WKT strings; each one is parsed back
    into a shapely geometry and the column becomes the active geometry.

    Returns:
        A ``GeoDataFrame`` with ``coordinates`` as its geometry column.
    """
    feature_group = connect_hopsworks().get_feature_group(name=FEATURE_GROUP_NAME)
    frame = feature_group.read()
    # WKT text -> shapely geometry, element by element.
    frame["coordinates"] = frame["coordinates"].map(wkt_loads)
    return gpd.GeoDataFrame(frame, geometry="coordinates")

# Create a Folium map with the GeoDataFrame
def create_map(gdf, feature='predicted_rs', *, location=(59.34318, 18.05141), zoom_start=14):
    """Render the geometries in *gdf* as a colour-coded Folium map.

    Each geometry in the ``coordinates`` column is added as its own GeoJSON
    layer, coloured with matplotlib's RdYlGn colormap according to the value
    in ``feature``. Values are normalised against the fixed module range
    ``[min_map, max_map]`` rather than the data's own min/max. This is the
    same range display_maps() uses for the legend.

    Args:
        gdf: GeoDataFrame with a ``coordinates`` geometry column and a
            numeric ``feature`` column.
        feature: Column whose values choose each line's colour.
        location: ``(lat, lon)`` of the initial map centre; defaults to
            the original Stockholm view.
        zoom_start: Initial zoom level passed to ``folium.Map``.

    Returns:
        The map as an HTML string (``folium.Map._repr_html_()``).
    """
    m = folium.Map(location=list(location), zoom_start=zoom_start)

    norm = mcolors.Normalize(vmin=min_map, vmax=max_map)
    cmap = plt.cm.RdYlGn

    # Iterate the two needed columns directly; iterrows() would build a
    # whole Series object for every row.
    for geometry, value in zip(gdf['coordinates'], gdf[feature]):
        color = mcolors.to_hex(cmap(norm(value)))
        folium.GeoJson(
            geometry,
            # Bind the colour as a default argument so each layer keeps its
            # own colour (avoids the late-binding closure pitfall).
            style_function=lambda _feature, color=color: {'color': color},
        ).add_to(m)

    return m._repr_html_()

# Create a histogram of predicted_rs values and print the array
def create_histogram(gdf, feature='predicted_rs', bins=20):
    """Compute, print and return a histogram of one numeric column.

    Missing values are dropped first. ``np.histogram`` raises ``ValueError``
    when it tries to auto-detect its range from data that contains NaN.

    Args:
        gdf: DataFrame (or GeoDataFrame) containing ``feature``.
        feature: Column to histogram. Defaults to ``'predicted_rs'``.
        bins: Number of equal-width bins, passed to ``np.histogram``.

    Returns:
        ``(hist, bin_edges)`` exactly as returned by ``np.histogram``.
    """
    values = gdf[feature].dropna()
    hist, bin_edges = np.histogram(values, bins=bins)
    print("Histogram counts:", hist)
    print("Bin edges:", bin_edges)
    return hist, bin_edges

def generate_legend_html(cmap, norm, n_colors=11):
    """Build an HTML colour legend that matches the map colouring.

    ``cmap`` is sampled at ``n_colors`` evenly spaced values across
    ``[norm.vmin, norm.vmax]``, and the samples become the stops of a branca
    ``LinearColormap``. Passing only the two endpoint colours would make
    branca interpolate straight between them. For a diverging colormap such
    as RdYlGn that skips the yellow midpoint, so the legend would not match
    the lines drawn on the map.

    Args:
        cmap: Matplotlib colormap used to colour the map lines.
        norm: ``matplotlib.colors.Normalize`` defining the value range.
        n_colors: Number of colour stops to sample (must be >= 2).

    Returns:
        The legend rendered as an HTML string.

    Raises:
        ValueError: If ``n_colors`` is less than 2.
    """
    if n_colors < 2:
        raise ValueError("n_colors must be at least 2")

    # Evenly spaced samples line up with branca's default regular index grid.
    sample_values = np.linspace(norm.vmin, norm.vmax, n_colors)
    colormap = branca.colormap.LinearColormap(
        colors=[mcolors.to_hex(cmap(norm(v))) for v in sample_values],
        vmin=norm.vmin,
        vmax=norm.vmax,
        caption='Relative Traffic Speed (1.0 = Free Flow | 0.0 = Stopped)'
    )
    return colormap._repr_html_()

# Gradio interface function
def display_maps():
    """Fetch the latest data and render both maps plus their shared legend.

    Returns:
        ``(current_map_html, predicted_map_html, legend_html)``. The first map
        is coloured by ``relativespeed``, the second by ``predicted_rs``.
    """
    segments = fetch_linestring_data()
    current_html = create_map(segments, feature='relativespeed')
    forecast_html = create_map(segments, feature='predicted_rs')

    # The legend uses the same fixed range as the maps so the colours agree.
    shared_norm = mcolors.Normalize(vmin=min_map, vmax=max_map)
    legend = generate_legend_html(plt.cm.RdYlGn, shared_norm)
    return current_html, forecast_html, legend

if __name__ == "__main__":
    # Layout: a header row with the trigger button, then the current and
    # forecast maps side by side, then a full-width legend row.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown("# Traffic in Stockholm near Odenplan")
            generate_btn = gr.Button("Generate Maps", variant="primary")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Current traffic speed")
                current_map_html = gr.HTML()
            with gr.Column():
                gr.Markdown("### Traffic speed in 1 hour")
                forecast_map_html = gr.HTML()
        with gr.Row():
            legend_box = gr.HTML()

        # display_maps() returns its three HTML strings in this order.
        result_slots = [current_map_html, forecast_map_html, legend_box]
        generate_btn.click(display_maps, inputs=[], outputs=result_slots)
    demo.launch()