| from io import BytesIO |
|
|
| import numpy as np |
| import pandas as pd |
| import plotly.express as px |
| import streamlit as st |
| from hilbertcurve.hilbertcurve import HilbertCurve |
| from sklearn.cluster import KMeans |
|
|
|
|
def cluster_sites_hilbert_curve_same_size(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
):
    """Split sites into fixed-size clusters by walking a 2-D Hilbert curve.

    Within each region (or over the whole frame when ``mix_regions`` is true)
    the coordinates are rescaled onto a 2**16 x 2**16 integer grid. Each site is
    ranked by its distance along a Hilbert curve through that grid, and the
    ranked sites are cut into consecutive chunks of ``max_sites``. Every
    cluster therefore has exactly ``max_sites`` members, except possibly the
    last cluster of each region.

    Args:
        df: Input sites. Not modified.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Column whose values keep regions apart. Ignored when
            ``mix_regions`` is true.
        max_sites: Number of sites per cluster (chunk size).
        mix_regions: When true, all sites are ranked together regardless of
            region.

    Returns:
        A copy of the input rows with an added ``Cluster`` column holding
        labels ``"C0"``, ``"C1"``, ... . Rows are ordered by cluster. Sites
        whose region is missing are kept and clustered together instead of
        being dropped. User columns are never overwritten; only ``Cluster``
        is added.

    Raises:
        ValueError: If ``max_sites`` is smaller than 1 or a coordinate column
            contains missing values.
    """
    if max_sites < 1:
        raise ValueError("max_sites must be at least 1")
    if df.empty:
        # pd.concat([]) would raise; return an empty frame with the same schema.
        return df.iloc[:0].assign(Cluster=pd.Series(dtype=object))

    if mix_regions:
        groups = [("All", df)]
    else:
        # dropna=False keeps sites whose region is blank; the default would
        # silently remove them from the result.
        groups = df.groupby(region_col, dropna=False)

    order = 16  # bits per axis: coordinates map onto a 2**16 x 2**16 grid
    grid_max = 2**order - 1
    hilbert_curve = HilbertCurve(order, 2)

    clusters = []
    cluster_id = 0
    for _, group in groups:
        if group.empty:
            continue

        lat = group[lat_col].to_numpy(dtype=float, na_value=np.nan)
        lon = group[lon_col].to_numpy(dtype=float, na_value=np.nan)
        if np.isnan(lat).any() or np.isnan(lon).any():
            raise ValueError(
                f"Columns {lat_col!r} and {lon_col!r} must not contain missing values"
            )

        # Rescale each axis into [0, grid_max] for this group. The 1e-10 avoids
        # a division by zero when every site shares the same coordinate.
        x = ((lat - lat.min()) / (lat.max() - lat.min() + 1e-10) * grid_max).astype(
            np.int64
        )
        y = ((lon - lon.min()) / (lon.max() - lon.min() + 1e-10) * grid_max).astype(
            np.int64
        )

        distances = np.asarray(
            [
                hilbert_curve.distance_from_point([int(xi), int(yi)])
                for xi, yi in zip(x, y)
            ]
        )
        # Stable sort: sites falling in the same grid cell keep their input order.
        ranked = group.iloc[np.argsort(distances, kind="stable")]

        for start in range(0, len(ranked), max_sites):
            chunk = ranked.iloc[start : start + max_sites].copy()
            chunk["Cluster"] = f"C{cluster_id}"
            clusters.append(chunk)
            cluster_id += 1

    return pd.concat(clusters)
|
|
|
|
def cluster_sites_kmeans_lower_to_fixed_size(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
):
    """Cluster sites with repeated k-means so no cluster exceeds ``max_sites``.

    The algorithm works within each region, or over the whole frame when
    ``mix_regions`` is true.

    1. The still-unassigned sites are split by k-means on their raw
       (latitude, longitude) values into ``ceil(n / max_sites)`` clusters.
    2. Every k-means cluster with at most ``max_sites`` members is accepted.
    3. The remaining sites are re-clustered on the next pass.
    4. Once no more than ``max_sites`` sites remain, they form one final
       cluster.

    Cluster sizes therefore vary but never exceed ``max_sites``.

    If a pass accepts nothing (k-means can collapse, e.g. when many sites share
    identical coordinates), each oversized cluster is instead sorted by
    (lat, lon) and cut into chunks of ``max_sites``. This guarantees the loop
    terminates.

    Args:
        df: Input sites. Not modified.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Column whose values keep regions apart. Ignored when
            ``mix_regions`` is true.
        max_sites: Upper bound on the number of sites per cluster.
        mix_regions: When true, all sites are clustered together regardless
            of region.

    Returns:
        A copy of the input rows with an added ``Cluster`` column holding
        labels ``"C0"``, ``"C1"``, ... . Rows are grouped by cluster. Sites
        whose region is missing are kept and clustered together instead of
        being dropped.

    Raises:
        ValueError: If ``max_sites`` is smaller than 1.
    """
    if max_sites < 1:
        raise ValueError("max_sites must be at least 1")
    if df.empty:
        # pd.concat([]) would raise; return an empty frame with the same schema.
        return df.iloc[:0].assign(Cluster=pd.Series(dtype=object))

    if mix_regions:
        groups = [("All", df)]
    else:
        # dropna=False keeps sites whose region is blank; the default would
        # silently remove them from the result.
        groups = df.groupby(region_col, dropna=False)

    clusters = []
    next_id = 0

    def emit(rows: pd.DataFrame) -> None:
        """Append a copy of *rows* tagged with the next cluster label."""
        nonlocal next_id
        tagged = rows.copy()
        tagged["Cluster"] = f"C{next_id}"
        clusters.append(tagged)
        next_id += 1

    for _, group in groups:
        remaining = group
        while len(remaining) > max_sites:
            n_clusters = int(np.ceil(len(remaining) / max_sites))
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(remaining[[lat_col, lon_col]].to_numpy())

            accepted = np.zeros(len(remaining), dtype=bool)
            # Only labels that actually occur are visited, so no empty cluster
            # is ever emitted.
            for label in np.unique(labels):
                members = labels == label
                if members.sum() <= max_sites:
                    emit(remaining[members])
                    accepted |= members

            if not accepted.any():
                # Every cluster is oversized (k-means found fewer distinct
                # clusters than requested). Re-running would give the same
                # answer forever, so chunk each one deterministically instead.
                for label in np.unique(labels):
                    oversized = remaining[labels == label].sort_values(
                        [lat_col, lon_col], kind="stable"
                    )
                    for start in range(0, len(oversized), max_sites):
                        emit(oversized.iloc[start : start + max_sites])
                accepted[:] = True

            # Boolean masking (rather than dropping by index) is safe even
            # when the input index has duplicate labels.
            remaining = remaining[~accepted]

        if len(remaining) > 0:
            emit(remaining)

    return pd.concat(clusters)
|
|
|
|
def to_excel(df: pd.DataFrame) -> bytes:
    """Serialise *df* to an in-memory .xlsx workbook and return its bytes.

    The frame is written without its index to a single sheet named
    ``"Clusters"`` using the xlsxwriter engine. Nothing touches the disk.
    """
    buffer = BytesIO()
    writer = pd.ExcelWriter(buffer, engine="xlsxwriter")
    # Closing the writer (on leaving the with-block) flushes the workbook
    # into the buffer, so the bytes must be read afterwards.
    with writer:
        df.to_excel(writer, sheet_name="Clusters", index=False)
    return buffer.getvalue()
|
|
|
|
st.title("Automatic Site Clustering App")

# Short usage note shown above the sample download and the uploader.
st.write(
    """This app allows you to cluster sites based on their latitude and longitude.
**Please choose a file containing the latitude, longitude, region, and site code columns.**
"""
)

# Template workbook that shows users the expected column layout.
clustering_sample_file_path = "samples/Site_Clustering.xlsx"

# Read the sample inside a context manager so the file handle is closed, and
# keep the rest of the app usable if the sample is missing from the deployment.
try:
    with open(clustering_sample_file_path, "rb") as sample_file:
        sample_bytes = sample_file.read()
except FileNotFoundError:
    sample_bytes = None

if sample_bytes is None:
    st.warning(f"Sample file not found: {clustering_sample_file_path}")
else:
    st.download_button(
        label="Download Clustering Sample File",
        data=sample_bytes,
        file_name="Site_Clustering.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
|
|
uploaded_file = st.file_uploader("Upload your Excel file ", type=["xlsx"])

if uploaded_file:
    df = pd.read_excel(uploaded_file)
    st.write("Sample of uploaded data:", df.head())

    columns = df.columns.tolist()

    # A form batches the widget changes so clustering only runs on submit.
    with st.form("clustering_form"):
        lat_col = st.selectbox("Select Latitude column", columns)
        lon_col = st.selectbox("Select Longitude column", columns)
        region_col = st.selectbox("Select Region column", columns)
        code_col = st.selectbox("Select Site Code column", columns)
        max_sites = st.number_input(
            "Max sites per cluster", min_value=5, max_value=100, value=25
        )
        cluster_method = st.selectbox(
            "Select clustering method",
            [
                "Uniform number of sites for each cluster",
                "Number of sites Lower than max but not uniform",
            ],
        )
        mix_regions = st.checkbox(
            "Allow mixing different regions in clusters", value=False
        )
        submitted = st.form_submit_button("Run Clustering")

    if submitted:
        # Map each selectbox label to its clustering routine. A dict lookup
        # cannot leave clustered_df unbound the way an if/elif chain could.
        clustering_functions = {
            "Uniform number of sites for each cluster": (
                cluster_sites_hilbert_curve_same_size
            ),
            "Number of sites Lower than max but not uniform": (
                cluster_sites_kmeans_lower_to_fixed_size
            ),
        }
        clustered_df = clustering_functions[cluster_method](
            df, lat_col, lon_col, region_col, max_sites, mix_regions
        )
        st.success("Clustering completed!")

        # Labels look like "C<number>". Sort them numerically so "C10" comes
        # after "C9", not after "C1" as a plain string sort would place it.
        cluster_order = sorted(
            clustered_df["Cluster"].unique(),
            key=lambda label: int(str(label)[1:]),
        )

        # Bar chart: number of sites in each cluster.
        cluster_size = clustered_df["Cluster"].value_counts().reindex(cluster_order)
        fig = px.bar(cluster_size, x=cluster_size.index, y=cluster_size.values)
        fig.update_layout(title="Cluster Size")
        st.plotly_chart(fig)

        # Stacked bar chart: cluster sizes broken down by region.
        cluster_size_per_region = (
            clustered_df.groupby([region_col, "Cluster"])
            .size()
            .reset_index(name="count")
        )
        fig = px.bar(
            cluster_size_per_region,
            x="Cluster",
            y="count",
            color=region_col,
            category_orders={"Cluster": cluster_order},
        )
        fig.update_layout(title="Cluster Size per Region")
        st.plotly_chart(fig)

        # Map of every site coloured by cluster. scatter_map uses the MapLibre
        # "map" layout, so the style goes through map_style; the old
        # mapbox_style setting had no effect on this figure type. Marker size
        # is set on the traces so no helper column is added to clustered_df,
        # which would otherwise leak into the Excel download.
        fig = px.scatter_map(
            clustered_df,
            lat=lat_col,
            lon=lon_col,
            color="Cluster",
            hover_name=code_col,
            hover_data=[region_col],
            category_orders={"Cluster": cluster_order},
            zoom=5,
            height=600,
            map_style="open-street-map",
        )
        fig.update_traces(marker=dict(size=15))
        st.plotly_chart(fig)

        # on_click="ignore" stops the download from rerunning the script,
        # which would otherwise clear the results shown above.
        st.download_button(
            label="Download clustered Excel file",
            data=to_excel(clustered_df),
            file_name="clustered_sites.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            on_click="ignore",
            type="primary",
        )
|
|