from datetime import date
from typing import TypedDict

import pandas as pd
import streamlit as st

from config import AppConfig
from utils.date_utils import get_reporting_year
from utils.summary import (
    create_multiindex_columns,
    create_overall_summary,
    create_summary_by_station_and_position,
)
from utils.timing import timer


class DatasetMetadata(TypedDict):
    # Shape of the metadata dict produced by get_dataset_metadata().
    total_records: int
    date_range: dict[str, date]  # keys "start" / "end"
    years: list[int]
    stations: int
    records_by_year: dict[int, int]
    reporting_year_end_month: int


class DataManager:
    """Facade over the raw parquet dataset.

    Handles loading, date-range and exclusion filtering, and lazily-computed
    metadata / sector / station lists. Errors surface via st.error and
    degrade to empty structures rather than raising, so the Streamlit app
    keeps rendering.
    """

    def __init__(self, config: AppConfig):
        self.config = config
        self._data_cache = None
        self._metadata: DatasetMetadata | None = None
        self._all_sectors: list[str] | None = None
        self._all_stations: list[str] | None = None
        self._initialize_complete_lists()

    def _initialize_complete_lists(self) -> None:
        """Initialize complete lists of sectors and stations from raw data."""
        try:
            raw_df = get_raw_data(self.config.DATA_FILE_PATH)

            # Handle sectors
            sectors = raw_df["Sector"].dropna().unique().tolist()
            self._all_sectors = sorted(sectors)

            # Handle stations - convert to float first to standardize the
            # numeric format (so e.g. "7" and "7.0" collapse to one station).
            stations = raw_df["Station_Number"].dropna()
            stations = stations.astype(float).astype(str).unique().tolist()
            self._all_stations = sorted(stations, key=float)
        except Exception as e:
            st.error(f"Failed to initialize complete lists: {str(e)}")
            self._all_sectors = []
            self._all_stations = []

    @property
    def all_sectors(self) -> list[str]:
        """Get complete list of all sectors in the dataset."""
        if self._all_sectors is None:
            self._initialize_complete_lists()
        return self._all_sectors if self._all_sectors is not None else []

    @property
    def all_stations(self) -> list[str]:
        """Get complete list of all stations in the dataset."""
        if self._all_stations is None:
            self._initialize_complete_lists()
        return self._all_stations if self._all_stations is not None else []

    @property
    def metadata(self) -> DatasetMetadata | None:
        """Dataset-wide metadata, loaded lazily; None if loading failed."""
        if self._metadata is None:
            self._load_metadata()
        return self._metadata

    def _load_metadata(self) -> None:
        try:
            raw_df = get_raw_data(self.config.DATA_FILE_PATH)
            self._metadata = get_dataset_metadata(
                raw_df, self.config.DEFAULT_REPORTING_MONTH
            )
        except Exception as e:
            st.error(f"Failed to load dataset metadata: {str(e)}")
            self._metadata = None

    def _load_data_internal(
        self,
        reporting_month: int,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> dict:
        """Internal method to load and process data.

        Returns a dict with keys "raw_df", "downloads" and
        "full_dataset_metadata" (metadata is computed BEFORE date/exclusion
        filtering so the date-input controls always span the full dataset).
        """
        raw_df = get_raw_data(self.config.DATA_FILE_PATH)
        # .copy() so the later column assignments below operate on an owned
        # frame, not a view of the cached parquet load.
        raw_df = raw_df[raw_df["Station_Number"].notna()].copy()

        # Get full dataset date range for the date input controls
        full_dataset_metadata = get_dataset_metadata(raw_df, reporting_month)

        # Apply date filters if provided
        if start_date is not None and end_date is not None:
            raw_df = filter_data_by_dates(raw_df, start_date, end_date)

        # Add reporting year based on provided reporting_month or default
        if reporting_month is not None:
            raw_df = raw_df.copy()  # defensive: filter above may return a slice
            raw_df["Reporting_Year"] = raw_df["Activity_Start_Date_Time"].apply(
                lambda x: get_reporting_year(x, reporting_month)
            )

        # Apply exclusion filters if they exist in session state
        if st.session_state.get("persistent_excluded_sectors"):
            raw_df = raw_df[
                ~raw_df["Sector"].isin(st.session_state.persistent_excluded_sectors)
            ]

        if st.session_state.get("persistent_excluded_stations"):
            # Convert station numbers to standardized string format for comparison
            df_stations = raw_df["Station_Number"].astype(float).astype(str)
            excluded_stations = [
                str(float(s)) for s in st.session_state.persistent_excluded_stations
            ]
            raw_df = raw_df[~df_stations.isin(excluded_stations)]

        downloads = prepare_downloads(raw_df)

        return {
            "raw_df": raw_df,
            "downloads": downloads,
            "full_dataset_metadata": full_dataset_metadata,
        }

    def _get_empty_data_structure(self) -> dict:
        """Return empty data structure for error cases."""
        # NOTE(review): real downloads dicts from prepare_downloads() only
        # have a "raw" key, and real metadata also carries
        # "reporting_year_end_month" — confirm which shape callers expect.
        return {
            "raw_df": pd.DataFrame(),
            "downloads": {"summary": {}, "raw": {}},
            "full_dataset_metadata": {
                "total_records": 0,
                "date_range": {"start": None, "end": None},
                "years": [],
                "stations": 0,
                "records_by_year": {},
            },
        }

    def load_data(
        self,
        start_date: date | None = None,
        end_date: date | None = None,
        reporting_month: int | None = None,
        force_refresh: bool = False,
    ) -> dict:
        """Load data with improved error handling and caching.

        force_refresh clears ALL st.cache_data caches (including
        get_raw_data). On any failure an empty data structure is returned
        instead of raising.
        """
        if force_refresh:
            st.cache_data.clear()

        try:
            # Ensure the exclusion keys exist in session state so downstream
            # code can rely on them.
            excluded_sectors = st.session_state.get("persistent_excluded_sectors", [])
            excluded_stations = st.session_state.get("persistent_excluded_stations", [])
            st.session_state.persistent_excluded_sectors = excluded_sectors
            st.session_state.persistent_excluded_stations = excluded_stations

            # `is not None` (not truthiness) so an explicit month is never
            # silently replaced by the default.
            return self._load_data_internal(
                reporting_month=reporting_month
                if reporting_month is not None
                else self.config.DEFAULT_REPORTING_MONTH,
                start_date=start_date,
                end_date=end_date,
            )

        except Exception as e:
            st.error(f"Failed to load data: {str(e)}")
            return self._get_empty_data_structure()


@timer(include_params=True)
def get_raw_data(file_path: str) -> pd.DataFrame:
    """Load raw data from parquet file."""
    return pd.read_parquet(file_path)


@timer(include_params=False)
def get_dataset_metadata(df: pd.DataFrame, reporting_month: int) -> DatasetMetadata:
    """Generate metadata about the dataset.

    NOTE(review): assumes df is non-empty with a datetime
    Activity_Start_Date_Time column; an empty frame would make
    .min().date() fail on NaT — confirm callers guarantee this.
    """
    return {
        "total_records": len(df),
        "date_range": {
            "start": df["Activity_Start_Date_Time"].min().date(),
            "end": df["Activity_Start_Date_Time"].max().date(),
        },
        "years": sorted(df["Activity_Start_Date_Time"].dt.year.unique()),
        "stations": df["Station_Number"].nunique(),
        "records_by_year": (
            df.groupby(df["Activity_Start_Date_Time"].dt.year).size().to_dict()
        ),  # type: ignore
        "reporting_year_end_month": reporting_month,
    }


@timer(include_params=False)
def filter_data_by_dates(
    df: pd.DataFrame, start_date: date, end_date: date
) -> pd.DataFrame:
    """Filter dataframe by date range (inclusive of both end dates).

    Returns the unfiltered frame (with a warning) when the range matches
    nothing, and the original frame on error. Does not mutate the input.
    """
    try:
        # assign() returns a new frame — the caller's df is left untouched.
        df = df.assign(
            Activity_Start_Date_Time=pd.to_datetime(df["Activity_Start_Date_Time"])
        )

        # Convert start_date to start of day and end_date to end of day
        start_datetime = pd.Timestamp(start_date).normalize()
        end_datetime = (
            pd.Timestamp(end_date) + pd.Timedelta(days=1) - pd.Timedelta(microseconds=1)
        )

        filtered_df = df[
            (df["Activity_Start_Date_Time"] >= start_datetime)
            & (df["Activity_Start_Date_Time"] <= end_datetime)
        ]

        if filtered_df.empty:
            st.warning("No data found for the selected date range")
            return df

        return filtered_df
    except Exception as e:
        st.error(f"Error filtering data: {str(e)}")
        return df


@st.cache_data
@timer(include_params=False)
def create_summaries(
    raw_df: pd.DataFrame,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Build (per-station summary, overall summary, multi-index view)."""
    summary_by_station = create_summary_by_station_and_position(raw_df)
    overall_summary = create_overall_summary(raw_df)
    multiindex_df = create_multiindex_columns(summary_by_station)
    return summary_by_station, overall_summary, multiindex_df


@timer(include_params=False)
def prepare_downloads(raw_df: pd.DataFrame) -> dict:
    """Package the raw dataframe into download payloads keyed by format."""
    return {
        "raw": {
            "CSV": (raw_df.to_csv(index=False), "csv", "text/csv"),
        },
    }


def add_lat_long(raw_df: pd.DataFrame, stations_df: pd.DataFrame) -> pd.DataFrame:
    """
    Add latitude and longitude to raw data based on station number.

    The join key is Station_Number cast to float so it matches the numeric
    "Number" column in stations_df. The input frame is not mutated.
    """
    raw_df = raw_df.copy()  # don't leak the temporary join column to the caller
    raw_df["Number"] = raw_df["Station_Number"].astype(float)
    raw_df = raw_df.merge(
        stations_df[["Number", "Latitude", "Longitude"]],
        on="Number",
        how="left",
    )
    return raw_df.drop("Number", axis=1)


@timer(include_params=False)
def get_stations_data() -> pd.DataFrame:
    """
    Return stations data as a dataframe with the most recent and earliest
    sample dates for each station.

    Reads the raw dataframe from st.session_state.data and the station
    locations from data/Stations-Locations.csv; stations with no samples
    are dropped.
    """
    raw_df = st.session_state.data["raw_df"]

    # Get date ranges for each station in one operation
    sample_dates = (
        raw_df.groupby("Station_Number")["Activity_Start_Date_Time"]
        .agg(["min", "max", "count"])
        .reset_index()
        .rename(
            columns={
                "min": "Earliest_Sample",
                "max": "Most_Recent_Sample",
                "count": "Total_Samples",
            }
        )
        .astype({"Station_Number": float, "Total_Samples": int})
    )

    # Merge with stations data and format dates
    return (
        pd.read_csv("data/Stations-Locations.csv")
        .merge(sample_dates, left_on="Number", right_on="Station_Number", how="left")
        .drop("Station_Number", axis=1)
        .assign(
            Most_Recent_Sample=lambda x: pd.to_datetime(x.Most_Recent_Sample).dt.date,
            Earliest_Sample=lambda x: pd.to_datetime(x.Earliest_Sample).dt.date,
        )
        .dropna(subset=["Total_Samples"])
    )


@timer(include_params=False)
def get_analyte_data_with_lat_long(df: pd.DataFrame, analyte: str) -> pd.DataFrame:
    """
    Extract and transform data for a specific analyte, adding geographical
    coordinates.

    This function processes raw water quality data by:
    1. Adding latitude/longitude coordinates from stations data
    2. Filtering for a specific analyte
    3. Removing rows with missing values
    4. Aggregating duplicate measurements using mean values

    Args:
        df (pd.DataFrame): Raw water quality data containing at minimum these
            columns: Station_Number, Org_Analyte_Name, Org_Result_Value,
            Reporting_Year
        analyte (str): Name of the analyte to filter for
            (e.g., "Temperature, Water")

    Returns:
        pd.DataFrame: Processed dataframe with columns:
            - Activity_Start_Date_Time: Timestamp of measurement
            - Station_Number: Monitoring station identifier
            - Sector: Geographical sector
            - WBID: Waterbody ID
            - Sample_Position: Position of sample (e.g., "Surface", "Bottom")
            - Activity_Depth: Depth of measurement
            - Latitude: Station latitude
            - Longitude: Station longitude
            - Reporting_Year: Reporting year
            - {analyte}: Measured value for the specified analyte

    Note:
        Duplicate measurements at the same location and time are averaged.
    """
    return (
        df.pipe(add_lat_long, get_stations_data())
        # Boolean mask instead of a .query() f-string: analyte names with
        # quotes or other query syntax would break/misfilter the query form.
        .loc[lambda d: d["Org_Analyte_Name"] == analyte]
        .dropna(subset=["Org_Result_Value"])
        .pivot_table(
            index=[
                "Activity_Start_Date_Time",
                "Station_Number",
                "Sector",
                "WBID",
                "Sample_Position",
                "Activity_Depth",
                "Latitude",
                "Longitude",
                "Reporting_Year",
            ],
            values="Org_Result_Value",
            aggfunc="mean",
            observed=True,
        )
        .reset_index()
        .rename(columns={"Org_Result_Value": analyte})
    )


@st.cache_data
@timer(include_params=False)
def load_seasonal_data(raw_df, analyte):
    """Load and prepare data for seasonal trends analysis."""
    return get_analyte_data_with_lat_long(raw_df, analyte)