# Source: waterdb/settings.py
# Last deployed by github-actions[bot]: "Deploy from GitHub Actions" (commit bc24113)
import calendar
from datetime import datetime
import streamlit as st
from streamlit.errors import StreamlitAPIException
from admin import AuthManager, login
from components import render_dataset_metadata, render_records_by_year
from config import AppConfig
from dashboard_analytics import log_visit
from utils.data_loading import get_dataset_metadata
from utils.session import ensure_session_initialized
# Page bootstrap: make sure the shared session objects exist, then record
# this page view.
ensure_session_initialized()
log_visit("Settings")

# Two-thirds / one-third page split; only the left column is used below.
col1, col2 = st.columns([2, 1])

data_manager = st.session_state.data_manager

# Filters persisted by earlier runs; the dates may be None when unset.
current_start, current_end = (
    st.session_state.get("start_date"),
    st.session_state.get("end_date"),
)
reporting_month: int = st.session_state.get(
    "reporting_month", AppConfig.DEFAULT_REPORTING_MONTH
)

# (Re)load the working data for the current filters on every script run.
st.session_state.data = data_manager.load_data(
    start_date=current_start,
    end_date=current_end,
    reporting_month=reporting_month,
)

# `full_metadata` comes from the data manager itself (presumably the
# unfiltered data set) and supplies the selectable date bounds; `metadata`
# summarizes the raw frame that was just loaded.
full_metadata = data_manager.metadata
metadata = get_dataset_metadata(
    st.session_state.data["raw_df"], reporting_month=reporting_month
)

# Widen the available range to whole months: first day of the earliest month
# through the last day of the latest month.
# NOTE(review): on_date_change compares st.date_input values against these
# bounds; confirm date_range holds types comparable with datetime.date.
dataset_span = full_metadata["date_range"]
raw_start = dataset_span["start"]
raw_end = dataset_span["end"]
_, days_in_last_month = calendar.monthrange(raw_end.year, raw_end.month)
min_date = raw_start.replace(day=1)
max_date = raw_end.replace(day=days_in_last_month)
def on_date_change():
    """Validate the Start/End Date pickers and publish the chosen period.

    Streamlit calls this when either date widget changes. A valid selection
    is copied to ``st.session_state.start_date`` / ``end_date``. An invalid
    selection shows an on-page error and leaves those values untouched.
    A date outside ``[min_date, max_date]`` also deletes the offending
    widget's key, so the picker is rebuilt from the stored period on the
    next run.
    """
    state = st.session_state

    # Both widgets must have registered their keys before they can be compared.
    if not all(key in state for key in ("dataset_start_date", "dataset_end_date")):
        return

    start, end = state.dataset_start_date, state.dataset_end_date

    # A cleared picker keeps the previously stored period.
    if start is None or end is None:
        st.error("Both start and end dates must be selected")
        return

    if start > end:
        st.error("Start date must be before end date")
        return

    before_range = start < min_date
    after_range = end > max_date
    if before_range or after_range:
        st.error(
            f"Dates must be between {min_date.strftime('%m/%d/%Y')} and {max_date.strftime('%m/%d/%Y')}"
        )
        # Drop only the widget(s) whose value is out of range.
        for is_bad, widget_key in (
            (before_range, "dataset_start_date"),
            (after_range, "dataset_end_date"),
        ):
            if is_bad:
                del state[widget_key]
        # Flip a session flag to force the UI to update.
        state["force_refresh"] = not state.get("force_refresh", False)
        return

    # Only a fully validated pair reaches the shared session filters.
    state.start_date = start
    state.end_date = end
with col1:
    # --- Reporting period ---------------------------------------------------
    st.subheader("Reporting Period")
    # Date filter controls. filter_col3 is left empty so the two pickers do not
    # stretch across the whole column.
    filter_col1, filter_col2, filter_col3 = st.columns(3, vertical_alignment="bottom")
    # Use existing values from session state, or the dataset bounds if not set;
    # the trailing `or` also replaces a stored None.
    current_start = st.session_state.get("start_date", min_date) or min_date
    current_end = st.session_state.get("end_date", max_date) or max_date
    with filter_col1:
        try:
            start_date = st.date_input(
                "Start Date",
                value=current_start,
                min_value=min_date,
                max_value=max_date,
                format="MM/DD/YYYY",
                key="dataset_start_date",
                on_change=on_date_change,
            )
        except StreamlitAPIException:
            # Streamlit refused to render the picker (the message below assumes
            # the stored date lies outside [min_date, max_date]); keep the
            # stored value so start_date is always bound.
            start_date = current_start
            st.error(
                f"Date must be between {min_date.strftime('%m/%d/%Y')} and {max_date.strftime('%m/%d/%Y')}"
            )
    with filter_col2:
        try:
            end_date = st.date_input(
                "End Date",
                value=current_end,
                min_value=min_date,
                max_value=max_date,
                format="MM/DD/YYYY",
                key="dataset_end_date",
                on_change=on_date_change,
            )
        except StreamlitAPIException:
            # Same fallback as for the start date.
            end_date = current_end
            st.error(
                f"Date must be between {min_date.strftime('%m/%d/%Y')} and {max_date.strftime('%m/%d/%Y')}"
            )

    # --- Reporting year end month -------------------------------------------
    config = st.session_state.get("config") or AppConfig.from_env()
    initial_reporting_month = st.session_state.get(
        "reporting_month", config.DEFAULT_REPORTING_MONTH
    )

    def on_reporting_month_change():
        """Copy the selected month into ``st.session_state.reporting_month``."""
        if "dataset_reporting_month" in st.session_state:
            st.session_state.reporting_month = st.session_state.dataset_reporting_month

    filter_row2_col1, _ = st.columns([1, 2])
    with filter_row2_col1:
        # Rebinds the page-level `reporting_month`; options are the month
        # numbers 1-12, displayed as month names.
        reporting_month = st.selectbox(
            "Reporting Year End Month",
            options=range(1, 13),
            format_func=lambda x: datetime(2000, x, 1).strftime("%B"),
            index=initial_reporting_month - 1,
            key="dataset_reporting_month",
            on_change=on_reporting_month_change,
        )

    # --- Data exclusions ----------------------------------------------------
    st.subheader("Data Exclusions")
    exclusion_col1, exclusion_col2 = st.columns(2)

    def _reload_data():
        """Reload ``st.session_state.data`` using the current session filters.

        Shared by the exclusion callbacks and the exclusion-triggered reload
        below. When no reporting month is stored it falls back to
        ``AppConfig.DEFAULT_REPORTING_MONTH``, matching the load at the top of
        the page, instead of passing ``reporting_month=None``.
        """
        st.session_state.data = st.session_state.data_manager.load_data(
            start_date=st.session_state.get("start_date"),
            end_date=st.session_state.get("end_date"),
            reporting_month=st.session_state.get(
                "reporting_month", AppConfig.DEFAULT_REPORTING_MONTH
            ),
        )

    def _valid_defaults(selected, options):
        """Return the entries of *selected* that are still present in *options*.

        st.multiselect raises StreamlitAPIException when a default value is
        not among its options. That would happen if a persisted exclusion no
        longer appears in the data manager's lists.
        """
        allowed = set(options)
        return [item for item in selected if item in allowed]

    def on_sector_exclusion_change():
        """Persist the sector exclusion selection and reload the data."""
        if "excluded_sectors_widget" in st.session_state:
            # Copy the widget value into a page-owned key; presumably so the
            # selection survives when the widget key is discarded (TODO confirm).
            st.session_state.persistent_excluded_sectors = (
                st.session_state.excluded_sectors_widget
            )
            # Reload data with new exclusions (assumes load_data reads the
            # persistent_excluded_* lists from session state - confirm).
            _reload_data()

    def on_station_exclusion_change():
        """Persist the station exclusion selection and reload the data."""
        if "excluded_stations_widget" in st.session_state:
            # Same persistence scheme as for sectors.
            st.session_state.persistent_excluded_stations = (
                st.session_state.excluded_stations_widget
            )
            # Reload data with new exclusions
            _reload_data()

    # Initialize persistent storage if it does not exist yet.
    if "persistent_excluded_sectors" not in st.session_state:
        st.session_state.persistent_excluded_sectors = []
    if "persistent_excluded_stations" not in st.session_state:
        st.session_state.persistent_excluded_stations = []

    # Reload data if there are any exclusions.
    # NOTE(review): the load at the top of the page already uses the same
    # session filters; confirm whether this second load is still needed.
    if st.session_state.get("persistent_excluded_sectors") or st.session_state.get(
        "persistent_excluded_stations"
    ):
        _reload_data()

    with exclusion_col1:
        # Get complete list of sectors from data manager
        all_sectors = st.session_state.data_manager.all_sectors
        st.multiselect(
            "Exclude Sectors",
            options=all_sectors,
            default=_valid_defaults(
                st.session_state.persistent_excluded_sectors, all_sectors
            ),
            help="Select sectors to exclude from all analyses",
            key="excluded_sectors_widget",
            on_change=on_sector_exclusion_change,
        )
    with exclusion_col2:
        # Get complete list of stations from data manager
        all_stations = st.session_state.data_manager.all_stations
        st.multiselect(
            "Exclude Stations",
            options=all_stations,
            default=_valid_defaults(
                st.session_state.persistent_excluded_stations, all_stations
            ),
            help="Select stations to exclude from all analyses",
            key="excluded_stations_widget",
            on_change=on_station_exclusion_change,
        )
# Dataset summary for the selectable date bounds, followed by a per-year
# record breakdown using the reporting year end month chosen above.
render_dataset_metadata(metadata, min_date, max_date)
current_raw_df = st.session_state.data["raw_df"]
render_records_by_year(current_raw_df, reporting_month)
# Admin login form, collapsed in an expander in the left half of the page.
admin_col, _ = st.columns([1, 1])
with admin_col, st.expander("Admin Login"):
    login(AuthManager(config))