# waterdb / components.py
# github-actions[bot]
# Deploy from GitHub Actions
# 7f2633f
import calendar
import datetime
import io
from typing import List, Literal, Optional
import pandas as pd
import plotly.express as px
import streamlit as st
from matplotlib import pyplot as plt
from plots.correlation import plot_parameter_correlations
from utils.data_loading import DatasetMetadata
from utils.date_utils import format_reporting_period, get_reporting_year
from utils.timing import timer
def render_sidebar_position_filter_selectbox(
    key_prefix: str = "", default: str = "All"
) -> str:
    """Render the sidebar selectbox used to filter by sample position.

    Args:
        key_prefix: Prefix that makes the Streamlit widget key unique per page.
        default: Option selected initially. It must be one of "All", "Surface"
            or "Bottom"; ``list.index`` raises ValueError otherwise.

    Returns:
        The position option currently selected in the widget.
    """
    position_options = ["All", "Surface", "Bottom"]
    initial_index = position_options.index(default)
    return st.sidebar.selectbox(
        "Sample Position:",
        position_options,
        index=initial_index,
        key=f"{key_prefix}_position_filter",
    )
@timer(include_params=False)
def render_sidebar_analyte_multiselect(
    all_analytes: List[str],
    default_analytes: Optional[List[str]] = None,
    key_prefix: str = "",
    help_text: str = "Choose one or more analytes to display.",
) -> List[str]:
    """Render the sidebar multiselect used to choose analytes.

    Every change is mirrored by an ``on_change`` callback into
    ``st.session_state["<key_prefix>_analyte_select_previous_value"]``. That
    remembered selection is used as the widget default on later renders, in
    preference to ``default_analytes``.

    Args:
        all_analytes: Options offered by the widget.
        default_analytes: Initial selection when nothing has been remembered
            yet. Entries not present in ``all_analytes`` are ignored.
        key_prefix: Prefix that makes the Streamlit widget key unique per page.
        help_text: Tooltip shown next to the widget label.

    Returns:
        The analytes currently selected in the widget.
    """
    if default_analytes is None:
        default_analytes = []
    available = set(all_analytes)

    widget_key = f"{key_prefix}_analyte_select"
    memory_key = f"{widget_key}_previous_value"

    if memory_key in st.session_state:
        candidates = st.session_state[memory_key] or []
    else:
        candidates = default_analytes
    # Streamlit rejects defaults that are not among the options. The remembered
    # selection may come from a render with a different option list, so it is
    # filtered the same way as the caller-supplied defaults.
    values = [analyte for analyte in candidates if analyte in available]

    def on_analyte_multiselect_change():
        st.session_state[memory_key] = st.session_state[widget_key]

    return st.sidebar.multiselect(
        "Select Parameters:",
        default=values,
        options=all_analytes,
        key=widget_key,
        help=help_text,
        on_change=on_analyte_multiselect_change,
    )
def render_sidebar_analyte_pills(
    all_analytes: List[str],
    default_analytes: Optional[List[str]] = None,
    key_prefix: str = "",
    help_text: str = "Choose one or more analytes to display.",
) -> List[str | None]:
    """Render a multi-selection pills widget in the sidebar for choosing analytes.

    Every change is mirrored by an ``on_change`` callback into
    ``st.session_state["<key_prefix>_analyte_select_previous_value"]``. That
    remembered selection is used as the widget default on later renders, in
    preference to ``default_analytes``.

    Args:
        all_analytes: Options offered as pills.
        default_analytes: Initial selection when nothing has been remembered
            yet. Entries not present in ``all_analytes`` are ignored.
        key_prefix: Prefix that makes the Streamlit widget key unique per page.
        help_text: Tooltip shown next to the widget label.

    Returns:
        The analytes currently selected in the widget.
    """
    if default_analytes is None:
        default_analytes = []
    available = set(all_analytes)

    widget_key = f"{key_prefix}_analyte_select"
    memory_key = f"{widget_key}_previous_value"

    if memory_key in st.session_state:
        candidates = st.session_state[memory_key] or []
    else:
        candidates = default_analytes
    # Streamlit rejects defaults that are not among the options. The remembered
    # selection may come from a render with a different option list, so it is
    # filtered the same way as the caller-supplied defaults.
    values = [analyte for analyte in candidates if analyte in available]

    def on_analyte_pills_change():
        st.session_state[memory_key] = st.session_state[widget_key]

    return st.sidebar.pills(
        "Select Parameters:",
        selection_mode="multi",
        default=values,
        options=all_analytes,
        key=widget_key,
        help=help_text,
        on_change=on_analyte_pills_change,
    )
def render_filtered_data_preview(
    df: pd.DataFrame,
    display_columns: List[str],
    set_index_col: str | None = None,
    height: int = 300,
) -> None:
    """Show a record count followed by a scrollable table of selected columns.

    Args:
        df: Data to preview.
        display_columns: Columns to show, in order.
        set_index_col: Optional column to move into the index. The index is
            hidden in the rendered table, and ``set_index`` drops the column by
            default, so this column is not displayable. Otherwise the existing
            index is reset into a regular column.
        height: Table height in pixels.
    """
    preview = df.set_index(set_index_col) if set_index_col else df.reset_index()
    record_count = len(preview)
    st.markdown(f"**{record_count:,}** records")

    # Numbers are rounded to two decimals for display only.
    styled = preview[display_columns].style.format(precision=2)
    st.dataframe(styled, use_container_width=True, height=height, hide_index=True)
@timer(include_params=False)
def render_stations_map(stations_df: pd.DataFrame) -> None:
    """Render an interactive scatter map of the monitoring stations.

    The view is centred on the mean station coordinate and zoomed according
    to how widely the stations are spread.

    Args:
        stations_df: One row per station. It must contain the ``Latitude``,
            ``Longitude``, ``Name``, ``Number``, ``Sector``, ``WBID``,
            ``Most_Recent_Sample`` and ``Total_Samples`` columns, all of which
            are used for positioning or hover text below.
    """
    center_lat = stations_df["Latitude"].mean()
    center_lon = stations_df["Longitude"].mean()
    # The mean is NaN when there are no rows or no non-missing coordinates.
    # A NaN centre cannot position the map, so report it instead.
    if pd.isna(center_lat) or pd.isna(center_lon):
        st.info("No stations with coordinates to display.")
        return

    # Zoom out as the coordinate spread (in degrees) grows, clamped to 8-11.
    # The multiplier 5 is a hand-tuned value.
    lat_range = stations_df["Latitude"].max() - stations_df["Latitude"].min()
    lon_range = stations_df["Longitude"].max() - stations_df["Longitude"].min()
    zoom = min(11, max(8, 12 - max(lat_range, lon_range) * 5))

    fig = px.scatter_mapbox(
        stations_df,
        lat="Latitude",
        lon="Longitude",
        hover_data={
            "Number": True,
            "Sector": True,
            "WBID": True,
            "Latitude": False,
            "Longitude": False,
            "Most_Recent_Sample": True,
            "Total_Samples": True,
        },
        hover_name="Name",
        height=525,
        labels={
            "Number": "Station Number",
            "Sector": "Sector",
            "WBID": "WBID",
            "Most_Recent_Sample": "Last Sample",
            "Total_Samples": "Total Samples",
        },
    )
    fig.update_layout(
        mapbox_style="carto-positron",
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        mapbox=dict(center=dict(lat=center_lat, lon=center_lon), zoom=zoom),
    )
    st.plotly_chart(fig, use_container_width=True)
def render_dataset_metadata(
    metadata: DatasetMetadata, min_date: datetime.date, max_date: datetime.date
) -> None:
    """Render the "Dataset Metadata" summary section.

    The displayed date range comes from ``st.session_state["start_date"]`` and
    ``st.session_state["end_date"]`` when those are set and truthy. Otherwise
    it falls back to ``min_date`` and ``max_date``.

    Args:
        metadata: Reads the ``total_records``, ``reporting_year_end_month``
            (a 1-12 index into ``calendar.month_name``) and ``stations`` entries.
        min_date: Fallback start of the displayed range.
        max_date: Fallback end of the displayed range.
    """
    st.markdown("#### Dataset Metadata")
    st.markdown(f"**Total Records:** {metadata['total_records']:,}")

    range_start = st.session_state.get("start_date") or min_date
    range_end = st.session_state.get("end_date") or max_date
    start_text = range_start.strftime("%Y-%m-%d")
    end_text = range_end.strftime("%Y-%m-%d")
    st.markdown(f"**Date Range:** {start_text} to {end_text}")

    end_month_name = calendar.month_name[metadata["reporting_year_end_month"]]
    st.markdown(f"**Reporting Year End Month:** {end_month_name}")
    st.markdown(f"**Unique Stations:** {metadata['stations']}")
def render_records_by_year(raw_df: pd.DataFrame, reporting_month: int) -> None:
    """Render a table of record counts per reporting year, newest first.

    Side effect: adds (or overwrites) a ``Reporting_Year`` column on
    ``raw_df`` in place. It is derived from ``Activity_Start_Date_Time`` via
    ``get_reporting_year``.

    Args:
        raw_df: Sample records; must contain ``Activity_Start_Date_Time``.
        reporting_month: Month in which each reporting year ends.
    """
    table_col, _ = st.columns([3, 1])
    with table_col:
        st.markdown("#### Records by Year")

        raw_df["Reporting_Year"] = raw_df["Activity_Start_Date_Time"].apply(
            lambda ts: get_reporting_year(ts, reporting_month)
        )
        year_counts = raw_df.groupby("Reporting_Year").size()
        yearly_data = (
            year_counts.rename_axis("Reporting Year")
            .reset_index(name="Records")
            .sort_values("Reporting Year", ascending=False)
        )
        yearly_data["Reporting Period"] = yearly_data["Reporting Year"].apply(
            lambda year: format_reporting_period(year, reporting_month)
        )
        yearly_data = yearly_data[["Reporting Year", "Reporting Period", "Records"]]

        # Size the table to its rows (36 px each), capped at 600 px.
        # The extra 38 px is header room.
        table_height = min(len(yearly_data) * 36, 600) + 38
        st.dataframe(
            # Styler formatting is used here rather than column_config.
            yearly_data.style.format({"Records": "{:,}"}),
            hide_index=True,
            use_container_width=True,
            height=table_height,
        )
@timer(include_params=True)
def render_correlation_plots(
    plot_df: pd.DataFrame,
    selected_groups: List[str],
    subset_by: str,
    position_filter: str,
    year_filter: str,
    analyte_names: List[str],
    corr_method: Literal["pearson", "kendall", "spearman"],
) -> None:
    """Render correlation plots and associated data views for each group.

    Groups with no rows are skipped. Non-empty groups fill the two columns
    alternately, left then right, so skipped groups leave no gaps.

    Args:
        plot_df: DataFrame containing the plot data
        selected_groups: List of groups to display
        subset_by: Column name to group by
        position_filter: Sample position filter value
        year_filter: Year filter value (only used in the download file name)
        analyte_names: List of analytes to include in correlation
        corr_method: Correlation method to use
    """
    display_columns = [
        "Activity_Start_Date_Time",
        "Name",
        "Station_Number",
        "Sample_Position",
        "Org_Analyte_Name",
        "Org_Result_Value",
        "Org_Result_Unit",
        "Reporting_Year",
    ]
    cols = st.columns(2)
    rendered = 0  # counts only the groups actually drawn

    for subset in selected_groups:
        subset_df = plot_df[plot_df[subset_by] == subset]
        if subset_df.empty:
            continue
        col = cols[rendered % 2]
        rendered += 1

        # Generate and display correlation plot
        fig, pivot_df = plot_parameter_correlations(
            subset_df,
            analyte_names,
            subset_by,
            subset,
            position_filter,
            corr_method=corr_method,  # type: ignore
        )
        col.pyplot(fig)
        # Close this specific figure, not merely pyplot's "current" one,
        # so that no figures accumulate across reruns.
        plt.close(fig)

        # Create expandable data view section
        with col.expander(f"View {subset_by}: {subset} Data"):
            st.markdown("##### Raw Data Preview")
            render_filtered_data_preview(subset_df, display_columns)

            # DataFrame.to_csv returns the CSV text when no path is given.
            st.download_button(
                label=f"Download {subset} Data (CSV)",
                data=subset_df.to_csv(index=False),
                file_name=f"{subset_by}-{subset}_pos-{position_filter}_year-{year_filter}_correlation.csv",
                mime="text/csv",
            )
            st.markdown("---")
            st.markdown("##### Pivot Table")
            st.dataframe(pivot_df, use_container_width=True)
def get_reporting_year_info_message(reporting_month: int, example_year: int) -> str:
    """Generate standardized info message about reporting year grouping.

    A reporting year labelled ``Y`` ends in ``reporting_month`` of ``Y``. It
    starts in the following calendar month of ``Y - 1``. When
    ``reporting_month`` is December, it coincides with calendar year ``Y``
    (Jan-Dec).

    Args:
        reporting_month: Month (1-12) in which each reporting year ends.
        example_year: Reporting year used in the worked example.

    Returns:
        A markdown-formatted explanation suitable for ``st.info``.

    Raises:
        ValueError: If ``reporting_month`` is not in 1..12.
    """
    if not 1 <= reporting_month <= 12:
        raise ValueError(f"reporting_month must be in 1..12, got {reporting_month!r}")
    # Month after the end month, wrapping December -> January.
    start_month = reporting_month % 12 + 1
    # e.g. ending in November 2024 means starting in December 2023, whereas
    # ending in December 2024 means starting in January 2024.
    start_year = example_year if reporting_month == 12 else example_year - 1
    return f"""
📅 **Data is grouped by reporting years**:
- Each reporting year ends in **{calendar.month_name[reporting_month]}**. This can be changed in *Settings*.
- Example: Reporting year **{example_year}** covers
**{calendar.month_abbr[start_month]} {start_year}**
through **{calendar.month_abbr[reporting_month]} {example_year}**.
"""
def render_quick_action_buttons() -> None:
    """Render a row of three quick-navigation buttons, one per column.

    Each button is created with ``disabled=True``. When enabled, its
    ``on_click`` callback calls ``st.switch_page`` with the associated page.
    """
    # (label, tooltip, target page)
    # NOTE(review): "View by Waterbody" targets pages/seasonal_trends.py;
    # confirm that this label/page pairing is intended.
    actions = (
        (
            "📊 View by Sector",
            "Analyze water quality trends by sector",
            "pages/sector_trends.py",
        ),
        (
            "💧 View by Waterbody",
            "Explore data grouped by waterbody",
            "pages/seasonal_trends.py",
        ),
        (
            "🔍 Custom Analysis",
            "Create custom data visualizations",
            "pages/parameter_correlations.py",
        ),
    )
    for column, (label, tooltip, page) in zip(st.columns(3), actions):
        with column:
            st.button(
                label,
                help=tooltip,
                use_container_width=True,
                # Bind `page` as a default so each callback keeps its own target.
                on_click=lambda target=page: st.switch_page(target),
                disabled=True,
            )