import calendar
import datetime
import io
from typing import List, Literal, Optional

import pandas as pd
import plotly.express as px
import streamlit as st
from matplotlib import pyplot as plt

from plots.correlation import plot_parameter_correlations
from utils.data_loading import DatasetMetadata
from utils.date_utils import format_reporting_period, get_reporting_year
from utils.timing import timer


def render_sidebar_position_filter_selectbox(
    key_prefix: str = "", default: str = "All"
) -> str:
    """Render the sample-position selectbox in the sidebar.

    Args:
        key_prefix: Prefix used to build the widget key
            (``f"{key_prefix}_position_filter"``).
        default: Option selected initially; must be one of ``"All"``,
            ``"Surface"`` or ``"Bottom"``.

    Returns:
        The value returned by the selectbox widget.

    Raises:
        ValueError: If ``default`` is not one of the available options.
    """
    options = ["All", "Surface", "Bottom"]
    return st.sidebar.selectbox(
        "Sample Position:",
        options,
        index=options.index(default),
        key=f"{key_prefix}_position_filter",
    )


def _initial_analyte_selection(
    all_analytes: List[str],
    default_analytes: Optional[List[str]],
    widget_key: str,
) -> List[str]:
    """Return the initial selection for an analyte widget.

    The selection remembered under ``f"{widget_key}_previous_value"`` in
    ``st.session_state`` takes precedence over ``default_analytes``. Either
    way, entries that are not present in ``all_analytes`` are dropped so the
    widget is never handed a default that is missing from its options.
    """
    previous_key = f"{widget_key}_previous_value"
    if previous_key in st.session_state:
        candidates = st.session_state[previous_key]
    else:
        candidates = default_analytes
    return [analyte for analyte in candidates or [] if analyte in all_analytes]


@timer(include_params=False)
def render_sidebar_analyte_multiselect(
    all_analytes: List[str],
    default_analytes: Optional[List[str]] = None,
    key_prefix: str = "",
    help_text: str = "Choose one or more analytes to display.",
) -> List[str]:
    """Render the analyte multi-select in the sidebar.

    Args:
        all_analytes: Options offered by the widget.
        default_analytes: Analytes selected initially when no previous
            selection is stored in the session state.
        key_prefix: Prefix used to build the widget key
            (``f"{key_prefix}_analyte_select"``).
        help_text: Tooltip shown next to the widget.

    Returns:
        The value returned by the multiselect widget.
    """
    widget_key = f"{key_prefix}_analyte_select"
    values = _initial_analyte_selection(all_analytes, default_analytes, widget_key)

    def on_analyte_multiselect_change():
        # Remember the latest selection so it can be restored as the default
        # the next time this widget is rendered.
        st.session_state[f"{widget_key}_previous_value"] = st.session_state[widget_key]

    return st.sidebar.multiselect(
        "Select Parameters:",
        default=values,
        options=all_analytes,
        key=widget_key,
        help=help_text,
        on_change=on_analyte_multiselect_change,
    )


def render_sidebar_analyte_pills(
    all_analytes: List[str],
    default_analytes: Optional[List[str]] = None,
    key_prefix: str = "",
    help_text: str = "Choose one or more analytes to display.",
) -> List[str | None]:
    """Render the analyte selector as multi-select pills in the sidebar.

    Args:
        all_analytes: Options offered by the widget.
        default_analytes: Analytes selected initially when no previous
            selection is stored in the session state.
        key_prefix: Prefix used to build the widget key
            (``f"{key_prefix}_analyte_select"``).
        help_text: Tooltip shown next to the widget.

    Returns:
        The value returned by the pills widget.
    """
    widget_key = f"{key_prefix}_analyte_select"
    values = _initial_analyte_selection(all_analytes, default_analytes, widget_key)

    def on_analyte_pills_change():
        # Remember the latest selection so it can be restored as the default
        # the next time this widget is rendered.
        st.session_state[f"{widget_key}_previous_value"] = st.session_state[widget_key]

    return st.sidebar.pills(
        "Select Parameters:",
        selection_mode="multi",
        default=values,
        options=all_analytes,
        key=widget_key,
        help=help_text,
        on_change=on_analyte_pills_change,
    )


def render_filtered_data_preview(
    df: pd.DataFrame,
    display_columns: List[str],
    set_index_col: str | None = None,
    height: int = 300,
) -> None:
    """Render a record count and a table preview of the filtered dataset.

    Args:
        df: Data to preview; it is re-indexed into a new frame, the caller's
            frame is not modified.
        display_columns: Columns shown in the table.
        set_index_col: Column to use as the index. When falsy, the existing
            index is reset into regular columns instead.
        height: Table height passed to ``st.dataframe``.
    """
    if set_index_col:
        df = df.set_index(set_index_col)
    else:
        df = df.reset_index()

    st.markdown(f"**{len(df):,}** records")
    st.dataframe(
        df[display_columns].style.format(precision=2),
        use_container_width=True,
        height=height,
        hide_index=True,
    )


@timer(include_params=False)
def render_stations_map(stations_df: pd.DataFrame) -> None:
    """Render the interactive stations map.

    Args:
        stations_df: Station table. The code reads the ``Latitude``,
            ``Longitude``, ``Name``, ``Number``, ``Sector``, ``WBID``,
            ``Most_Recent_Sample`` and ``Total_Samples`` columns.
    """
    latitudes = stations_df["Latitude"]
    longitudes = stations_df["Longitude"]

    # Centre the map on the mean station position
    center_lat = latitudes.mean()
    center_lon = longitudes.mean()

    # Derive the zoom from the coordinate spread, clamped to the range [8, 11]
    lat_range = latitudes.max() - latitudes.min()
    lon_range = longitudes.max() - longitudes.min()
    zoom = min(
        11, max(8, 12 - max(lat_range, lon_range) * 5)
    )  # Adjust multiplier as needed

    fig = px.scatter_mapbox(
        stations_df,
        lat="Latitude",
        lon="Longitude",
        hover_data={
            "Number": True,
            "Sector": True,
            "WBID": True,
            "Latitude": False,
            "Longitude": False,
            "Most_Recent_Sample": True,
            "Total_Samples": True,
        },
        hover_name="Name",
        height=525,
        labels={
            "Number": "Station Number",
            "Sector": "Sector",
            "WBID": "WBID",
            "Most_Recent_Sample": "Last Sample",
            "Total_Samples": "Total Samples",
        },
    )
    fig.update_layout(
        mapbox_style="carto-positron",
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        mapbox=dict(center=dict(lat=center_lat, lon=center_lon), zoom=zoom),
    )
    st.plotly_chart(fig, use_container_width=True)


def render_dataset_metadata(
    metadata: DatasetMetadata, min_date: datetime.date, max_date: datetime.date
) -> None:
    """Render the dataset metadata section.

    Args:
        metadata: Dataset summary; the ``total_records``,
            ``reporting_year_end_month`` and ``stations`` entries are read.
        min_date: Start date shown when the session state holds no
            ``start_date``.
        max_date: End date shown when the session state holds no
            ``end_date``.
    """
    st.markdown("#### Dataset Metadata")
    st.markdown(f"**Total Records:** {metadata['total_records']:,}")

    # Prefer the dates stored in the session state over the supplied bounds
    current_start = st.session_state.get("start_date") or min_date
    current_end = st.session_state.get("end_date") or max_date
    date_range_text = f"**Date Range:** {current_start.strftime('%Y-%m-%d')} to {current_end.strftime('%Y-%m-%d')}"
    st.markdown(date_range_text)

    st.markdown(
        f"**Reporting Year End Month:** {calendar.month_name[metadata['reporting_year_end_month']]}"
    )
    st.markdown(f"**Unique Stations:** {metadata['stations']}")


def render_records_by_year(raw_df: pd.DataFrame, reporting_month: int) -> None:
    """Render a table of record counts per reporting year.

    Note:
        A ``Reporting_Year`` column is added to ``raw_df`` in place, so the
        caller's frame carries that column after this call.

    Args:
        raw_df: Records with an ``Activity_Start_Date_Time`` column.
        reporting_month: Month number passed to ``get_reporting_year`` and
            ``format_reporting_period``.
    """
    col1, _ = st.columns([3, 1])
    with col1:
        st.markdown("#### Records by Year")

        raw_df["Reporting_Year"] = raw_df["Activity_Start_Date_Time"].apply(
            lambda x: get_reporting_year(x, reporting_month)
        )
        reporting_year_counts = raw_df.groupby("Reporting_Year").size().to_dict()

        yearly_data = pd.DataFrame(
            list(reporting_year_counts.items()),
            columns=["Reporting Year", "Records"],
        ).sort_values("Reporting Year", ascending=False)
        yearly_data["Reporting Period"] = yearly_data["Reporting Year"].apply(
            lambda x: format_reporting_period(x, reporting_month)
        )
        yearly_data = yearly_data[["Reporting Year", "Reporting Period", "Records"]]

        # Size the table to its rows (capped), plus a fixed allowance of 38
        row_height = 36
        max_height = 600
        calculated_height = min(len(yearly_data) * row_height, max_height) + 38

        st.dataframe(
            yearly_data.style.format(
                {"Records": "{:,}"}
            ),  # use this instead of column_config for Records column
            hide_index=True,
            use_container_width=True,
            height=calculated_height,
        )


@timer(include_params=True)
def render_correlation_plots(
    plot_df: pd.DataFrame,
    selected_groups: List[str],
    subset_by: str,
    position_filter: str,
    year_filter: str,
    analyte_names: List[str],
    corr_method: Literal["pearson", "kendall", "spearman"],
) -> None:
    """Render correlation plots and associated data views for each group.

    Groups are laid out alternately in two columns; groups with no rows in
    ``plot_df`` are skipped.

    Args:
        plot_df: DataFrame containing the plot data
        selected_groups: List of groups to display
        subset_by: Column name to group by
        position_filter: Sample position filter value
        year_filter: Year filter value
        analyte_names: List of analytes to include in correlation
        corr_method: Correlation method to use
    """
    display_columns = [
        "Activity_Start_Date_Time",
        "Name",
        "Station_Number",
        "Sample_Position",
        "Org_Analyte_Name",
        "Org_Result_Value",
        "Org_Result_Unit",
        "Reporting_Year",
    ]

    cols = st.columns(2)
    for idx, subset in enumerate(selected_groups):
        subset_df = plot_df[plot_df[subset_by] == subset]
        if subset_df.empty:
            continue

        column = cols[idx % 2]

        # Generate and display correlation plot
        fig, pivot_df = plot_parameter_correlations(
            subset_df,
            analyte_names,
            subset_by,
            subset,
            position_filter,
            corr_method=corr_method,  # type: ignore
        )
        column.pyplot(fig)
        # Close this specific figure; a bare plt.close() only closes whatever
        # pyplot currently considers the active figure.
        plt.close(fig)

        # Create expandable data view section
        with column.expander(f"View {subset_by}: {subset} Data"):
            st.markdown("##### Raw Data Preview")
            render_filtered_data_preview(subset_df, display_columns)

            # Add CSV download button
            csv_buffer = io.StringIO()
            subset_df.to_csv(csv_buffer, index=False)
            st.download_button(
                label=f"Download {subset} Data (CSV)",
                data=csv_buffer.getvalue(),
                file_name=f"{subset_by}-{subset}_pos-{position_filter}_year-{year_filter}_correlation.csv",
                mime="text/csv",
            )

            st.markdown("---")
            st.markdown("##### Pivot Table")
            st.dataframe(pivot_df, use_container_width=True)


def get_reporting_year_info_message(reporting_month: int, example_year: int) -> str:
    """Generate standardized info message about reporting year grouping.

    Args:
        reporting_month: Month number (1-12) in which each reporting year ends.
        example_year: Reporting year used for the worked example.

    Returns:
        A Markdown message describing the reporting-year convention.
    """
    # The period starts in the month after the reporting month. That month
    # belongs to the previous calendar year unless the reporting year ends in
    # December (e.g. ending November 2024 -> starts December 2023, but
    # ending December 2024 -> starts January 2024).
    start_month = (reporting_month % 12) + 1
    start_year = example_year if reporting_month == 12 else example_year - 1

    return f"""
    📅 **Data is grouped by reporting years**:
    - Each reporting year ends in **{calendar.month_name[reporting_month]}**. This can be changed in *Settings*.
    - Example: Reporting year **{example_year}** covers **{calendar.month_abbr[start_month]} {start_year}** through **{calendar.month_abbr[reporting_month]} {example_year}**.
    """


def render_quick_action_buttons() -> None:
    """Render three quick-action navigation buttons, all currently disabled.

    NOTE(review): every button passes ``disabled=True``, so the ``on_click``
    callbacks are not reachable at present. Before enabling them, confirm
    that ``st.switch_page`` is supported inside a widget callback.
    NOTE(review): "View by Waterbody" targets ``pages/seasonal_trends.py`` --
    confirm that this is the intended page.
    """
    col1, col2, col3 = st.columns(3)

    with col1:
        st.button(
            "📊 View by Sector",
            help="Analyze water quality trends by sector",
            use_container_width=True,
            on_click=lambda: st.switch_page("pages/sector_trends.py"),
            disabled=True,
        )

    with col2:
        st.button(
            "💧 View by Waterbody",
            help="Explore data grouped by waterbody",
            use_container_width=True,
            on_click=lambda: st.switch_page("pages/seasonal_trends.py"),
            disabled=True,
        )

    with col3:
        st.button(
            "🔍 Custom Analysis",
            help="Create custom data visualizations",
            use_container_width=True,
            on_click=lambda: st.switch_page("pages/parameter_correlations.py"),
            disabled=True,
        )