| from io import BytesIO |
|
|
| import numpy as np |
| import pandas as pd |
| import plotly.express as px |
| import streamlit as st |
| from sklearn.cluster import KMeans |
|
|
|
|
def _capacity_labels(coords: np.ndarray, max_sites: int) -> np.ndarray:
    """Label points with dense integers 0..k-1 so no label covers more than max_sites points.

    Groups small enough to fit are returned as a single label without running KMeans.
    Larger groups are split with KMeans into ceil(n / max_sites) clusters. Any resulting
    cluster that is still too big is split again recursively.
    """
    n_points = len(coords)
    if n_points <= max_sites:
        return np.zeros(n_points, dtype=int)

    n_clusters = int(np.ceil(n_points / max_sites))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    raw_labels = kmeans.fit_predict(coords)

    labels = np.empty(n_points, dtype=int)
    next_label = 0
    # np.unique both skips label values KMeans left unused and yields them in order,
    # so the relabelled output is dense (no gaps) and deterministic.
    for raw in np.unique(raw_labels):
        members = np.flatnonzero(raw_labels == raw)
        if len(members) == n_points:
            # KMeans could not separate these points (e.g. identical coordinates), so
            # recursing would never shrink the problem; chunk them in row order instead.
            sub_labels = np.arange(n_points) // max_sites
        else:
            # A k-means cluster can still exceed the cap; split it further.
            sub_labels = _capacity_labels(coords[members], max_sites)
        labels[members] = sub_labels + next_label
        next_label += int(sub_labels.max()) + 1
    return labels


def cluster_sites(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
):
    """Assign every site a ``Cluster`` label based on its latitude/longitude.

    Sites are clustered separately per value of ``region_col``, or all together when
    ``mix_regions`` is True. No cluster contains more than ``max_sites`` sites.
    Cluster labels are strings ``"C0"``, ``"C1"``, ... and are unique across regions.

    Args:
        df: Input sites; it is not modified.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Name of the region column (ignored when ``mix_regions`` is True).
        max_sites: Maximum number of sites allowed in one cluster.
        mix_regions: When True, sites from different regions may share a cluster.

    Returns:
        A copy of the rows of ``df`` (grouped by region, original index kept) with an
        added ``Cluster`` column.

    Raises:
        ValueError: If ``max_sites`` is less than 1.
    """
    if max_sites < 1:
        raise ValueError(f"max_sites must be at least 1, got {max_sites}")

    if mix_regions:
        grouped = [("All", df)]
    else:
        # dropna=False keeps sites with a blank region instead of silently dropping them.
        grouped = df.groupby(region_col, dropna=False)

    clusters = []
    cluster_id = 0
    for _region, group in grouped:
        coords = group[[lat_col, lon_col]].to_numpy()
        labels = _capacity_labels(coords, max_sites)

        group = group.copy()
        group["Cluster"] = [f"C{cluster_id + label}" for label in labels]
        clusters.append(group)
        # Labels are dense 0..k-1, so shifting by max+1 never collides with the next group.
        if len(labels):
            cluster_id += int(labels.max()) + 1

    if not clusters:
        # Nothing to concatenate (e.g. empty input): return an empty frame with the column.
        empty = df.copy()
        empty["Cluster"] = pd.Series(index=df.index, dtype=object)
        return empty
    return pd.concat(clusters)
|
|
|
|
def to_excel(
    df: pd.DataFrame, *, sheet_name: str = "Clusters", engine=None
) -> bytes:
    """Serialize ``df`` to an in-memory ``.xlsx`` workbook and return its bytes.

    Args:
        df: Data to write; the index is not written.
        sheet_name: Name of the single worksheet.
        engine: Excel writer engine passed to ``pd.ExcelWriter``. ``None`` lets pandas
            choose (xlsxwriter when installed, otherwise openpyxl).

    Returns:
        The complete workbook as bytes, suitable for a download button.
    """
    output = BytesIO()
    # Not hard-coding "xlsxwriter" avoids a crash when that optional package is missing.
    with pd.ExcelWriter(output, engine=engine) as writer:
        df.to_excel(writer, index=False, sheet_name=sheet_name)
    # The workbook is only finalized once the writer is closed by the with-block.
    return output.getvalue()
|
|
|
|
st.title("Automatic Site Clustering App")

# Intro text; kept flush-left so Markdown does not render it as a code block.
st.write(
    """This app allows you to cluster sites based on their latitude and longitude.
**Please choose a file containing the latitude, longitude, region, and site code columns.**
"""
)
|
|
uploaded_file = st.file_uploader("Upload your Excel file", type=["xlsx"])

if uploaded_file is not None:
    df = pd.read_excel(uploaded_file)
    st.write("Sample of uploaded data:", df.head())

    columns = df.columns.tolist()

    # A form batches the widgets so clustering only runs when the button is pressed.
    with st.form("clustering_form"):
        lat_col = st.selectbox("Select Latitude column", columns)
        lon_col = st.selectbox("Select Longitude column", columns)
        region_col = st.selectbox("Select Region column", columns)
        code_col = st.selectbox("Select Site Code column", columns)
        max_sites = st.number_input(
            "Max sites per cluster", min_value=5, max_value=100, value=25
        )
        mix_regions = st.checkbox(
            "Allow mixing different regions in clusters", value=False
        )
        submitted = st.form_submit_button("Run Clustering")

    if submitted:
        if lat_col == lon_col:
            st.error("Latitude and longitude must be different columns.")
            st.stop()

        # KMeans and the map need numeric, non-missing coordinates; report bad rows
        # instead of failing with a traceback.
        coords = df[[lat_col, lon_col]].apply(pd.to_numeric, errors="coerce")
        bad_rows = coords.isna().any(axis=1)
        if bad_rows.any():
            st.error(
                f"{int(bad_rows.sum())} row(s) have missing or non-numeric values in "
                f"'{lat_col}' / '{lon_col}'. Please fix them and upload the file again."
            )
            st.stop()

        site_df = df.copy()
        site_df[[lat_col, lon_col]] = coords

        clustered_df = cluster_sites(
            site_df, lat_col, lon_col, region_col, max_sites, mix_regions
        )
        st.success("Clustering completed!")
        st.write(clustered_df.head())

        fig = px.scatter_map(
            clustered_df,
            lat=lat_col,
            lon=lon_col,
            color="Cluster",
            hover_name=code_col,
            hover_data=[region_col],
            zoom=5,
            height=600,
        )
        # scatter_map draws on the MapLibre "map" subplot, so the tile style is set via
        # map_style; mapbox_style only affects scatter_mapbox figures.
        fig.update_layout(map_style="open-street-map")
        st.plotly_chart(fig)

        # on_click="ignore" keeps the results on screen when the file is downloaded.
        st.download_button(
            label="Download clustered Excel file",
            data=to_excel(clustered_df),
            file_name="clustered_sites.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            on_click="ignore",
            type="primary",
        )
|
|