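# Spaces: Sleeping
# (The line above appears to be page-status text captured along with this
# file; it is not part of the module.)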
| from datetime import date | |
| from typing import TypedDict | |
| import pandas as pd | |
| import streamlit as st | |
| from config import AppConfig | |
| from utils.date_utils import get_reporting_year | |
| from utils.summary import ( | |
| create_multiindex_columns, | |
| create_overall_summary, | |
| create_summary_by_station_and_position, | |
| ) | |
| from utils.timing import timer | |
class DatasetMetadata(TypedDict):
    """Summary statistics describing one loaded dataset.

    Instances are plain dictionaries; ``get_dataset_metadata`` is the
    function in this module that builds them.
    """

    # Number of rows in the frame the metadata was computed from.
    total_records: int
    # Earliest and latest sample dates, stored under "start" and "end".
    date_range: dict[str, date]
    # Calendar years that occur in the sample timestamps, ascending.
    years: list[int]
    # Count of distinct station numbers.
    stations: int
    # Row count per calendar year.
    records_by_year: dict[int, int]
    # Reporting month the metadata was requested with.
    reporting_year_end_month: int
class DataManager:
    """Load, filter and describe the dataset behind the Streamlit app.

    The manager reads the parquet file named by ``config.DATA_FILE_PATH``,
    keeps the complete sector/station lists and the dataset metadata on the
    instance, and applies the date range plus the exclusion lists stored in
    ``st.session_state`` whenever data is loaded.
    """

    def __init__(self, config: AppConfig):
        """Store ``config`` and eagerly build the sector and station lists."""
        self.config = config
        self._data_cache = None
        self._metadata: DatasetMetadata | None = None
        self._all_sectors: list[str] | None = None
        self._all_stations: list[str] | None = None
        self._initialize_complete_lists()

    def _initialize_complete_lists(self) -> None:
        """Populate the complete sector and station lists from the raw data.

        Any failure is reported through ``st.error`` and leaves both lists
        empty (not ``None``), so the lazy accessors will not retry.
        """
        try:
            raw_df = get_raw_data(self.config.DATA_FILE_PATH)

            sectors = raw_df["Sector"].dropna().unique().tolist()
            self._all_sectors = sorted(sectors)

            # Going through float first gives every station one canonical
            # text form ("12" and "12.0" both become "12.0").
            stations = raw_df["Station_Number"].dropna()
            stations = stations.astype(float).astype(str).unique().tolist()
            self._all_stations = sorted(stations, key=float)
        except Exception as e:
            st.error(f"Failed to initialize complete lists: {str(e)}")
            self._all_sectors = []
            self._all_stations = []

    def all_sectors(self) -> list[str]:
        """Get complete list of all sectors in the dataset."""
        if self._all_sectors is None:
            self._initialize_complete_lists()
        return self._all_sectors if self._all_sectors is not None else []

    def all_stations(self) -> list[str]:
        """Get complete list of all stations in the dataset."""
        if self._all_stations is None:
            self._initialize_complete_lists()
        return self._all_stations if self._all_stations is not None else []

    def metadata(self) -> DatasetMetadata | None:
        """Return the dataset metadata, loading it on first use.

        Returns ``None`` when loading failed.
        """
        if self._metadata is None:
            self._load_metadata()
        return self._metadata

    def _load_metadata(self) -> None:
        """Compute metadata for the full file using the default reporting month.

        Failures are reported through ``st.error`` and leave the metadata
        as ``None``.
        """
        try:
            raw_df = get_raw_data(self.config.DATA_FILE_PATH)
            self._metadata = get_dataset_metadata(
                raw_df, self.config.DEFAULT_REPORTING_MONTH
            )
        except Exception as e:
            st.error(f"Failed to load dataset metadata: {str(e)}")
            self._metadata = None

    def _load_data_internal(
        self,
        reporting_month: int,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> dict:
        """Load the raw data and apply date and exclusion filters.

        Args:
            reporting_month: Month passed to ``get_reporting_year`` and
                recorded in the metadata.
            start_date: First day to keep; only used together with
                ``end_date``.
            end_date: Last day to keep; only used together with
                ``start_date``.

        Returns:
            A dict with the filtered frame (``raw_df``), its export payloads
            (``downloads``) and the metadata of the unfiltered dataset
            (``full_dataset_metadata``).
        """
        raw_df = get_raw_data(self.config.DATA_FILE_PATH)
        raw_df = raw_df[raw_df["Station_Number"].notna()]

        # Computed before any date filtering so the date input controls can
        # offer the full range of the dataset.
        full_dataset_metadata = get_dataset_metadata(raw_df, reporting_month)

        if start_date and end_date:
            raw_df = filter_data_by_dates(raw_df, start_date, end_date)

        if reporting_month is not None:
            # ``assign`` builds a new frame; writing the column into the
            # filtered slice directly triggers pandas' chained-assignment
            # warning.
            raw_df = raw_df.assign(
                Reporting_Year=raw_df["Activity_Start_Date_Time"].apply(
                    lambda ts: get_reporting_year(ts, reporting_month)
                )
            )

        excluded_sectors = st.session_state.get("persistent_excluded_sectors")
        if excluded_sectors:
            raw_df = raw_df[~raw_df["Sector"].isin(excluded_sectors)]

        excluded_stations = st.session_state.get("persistent_excluded_stations")
        if excluded_stations:
            # Compare in the same float-as-text form on both sides.
            station_keys = raw_df["Station_Number"].astype(float).astype(str)
            excluded_keys = {str(float(s)) for s in excluded_stations}
            raw_df = raw_df[~station_keys.isin(excluded_keys)]

        return {
            "raw_df": raw_df,
            "downloads": prepare_downloads(raw_df),
            "full_dataset_metadata": full_dataset_metadata,
        }

    def _get_empty_data_structure(self) -> dict:
        """Return the placeholder result used when loading fails.

        The metadata carries the same keys as a real ``DatasetMetadata``
        (including ``reporting_year_end_month``) so consumers can index it
        the same way; the date range entries are ``None``.
        """
        return {
            "raw_df": pd.DataFrame(),
            "downloads": {"summary": {}, "raw": {}},
            "full_dataset_metadata": {
                "total_records": 0,
                "date_range": {"start": None, "end": None},
                "years": [],
                "stations": 0,
                "records_by_year": {},
                "reporting_year_end_month": self.config.DEFAULT_REPORTING_MONTH,
            },
        }

    def load_data(
        self,
        start_date: date | None = None,
        end_date: date | None = None,
        reporting_month: int | None = None,
        force_refresh: bool = False,
    ) -> dict:
        """Load data, reporting failures in the UI instead of raising.

        Args:
            start_date: Optional first day of the date filter.
            end_date: Optional last day of the date filter.
            reporting_month: Reporting month; a falsy value falls back to
                ``config.DEFAULT_REPORTING_MONTH``.
            force_refresh: When true, Streamlit's data cache is cleared
                before loading.

        Returns:
            The dict produced by ``_load_data_internal``, or the empty
            placeholder structure when an exception occurred.
        """
        if force_refresh:
            # NOTE(review): this clears every ``st.cache_data`` cache in the
            # app; none of the functions visible in this file are decorated
            # with it - confirm which loaders rely on this.
            st.cache_data.clear()

        try:
            # Make sure both exclusion lists exist in session state before
            # the internal loader reads them.
            excluded_sectors = st.session_state.get("persistent_excluded_sectors", [])
            excluded_stations = st.session_state.get("persistent_excluded_stations", [])
            st.session_state.persistent_excluded_sectors = excluded_sectors
            st.session_state.persistent_excluded_stations = excluded_stations

            return self._load_data_internal(
                reporting_month=reporting_month or self.config.DEFAULT_REPORTING_MONTH,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as e:
            st.error(f"Failed to load data: {str(e)}")
            return self._get_empty_data_structure()
def get_raw_data(file_path: str) -> pd.DataFrame:
    """Read the parquet file at ``file_path`` and return it as a DataFrame."""
    frame = pd.read_parquet(file_path)
    return frame
def get_dataset_metadata(df: pd.DataFrame, reporting_month: int) -> "DatasetMetadata":
    """Generate metadata about the dataset.

    Missing timestamps are ignored for the date range and the per-year
    figures; ``total_records`` still counts every row of ``df``.

    Args:
        df: Frame with an ``Activity_Start_Date_Time`` datetime column and a
            ``Station_Number`` column.
        reporting_month: Stored unchanged under ``reporting_year_end_month``.

    Returns:
        The metadata dictionary. When ``df`` has no usable timestamp the
        date range entries are ``None`` and the per-year collections are
        empty.
    """
    timestamps = df["Activity_Start_Date_Time"].dropna()

    if timestamps.empty:
        date_range = {"start": None, "end": None}
        records_by_year: dict[int, int] = {}
    else:
        date_range = {
            "start": timestamps.min().date(),
            "end": timestamps.max().date(),
        }
        # Dropping missing timestamps first keeps the year values integral;
        # a single NaT would otherwise turn the whole series into floats.
        per_year = timestamps.dt.year.value_counts().sort_index()
        records_by_year = {int(year): int(n) for year, n in per_year.items()}

    return {
        "total_records": len(df),
        "date_range": date_range,
        "years": list(records_by_year),
        "stations": df["Station_Number"].nunique(),
        "records_by_year": records_by_year,
        "reporting_year_end_month": reporting_month,
    }
def filter_data_by_dates(
    df: pd.DataFrame, start_date: date, end_date: date
) -> pd.DataFrame:
    """Filter dataframe by date range, both days inclusive.

    The caller's frame is never modified; the returned frame carries the
    ``Activity_Start_Date_Time`` column converted to datetimes.

    Args:
        df: Frame with an ``Activity_Start_Date_Time`` column.
        start_date: First day to keep (from midnight).
        end_date: Last day to keep (through the end of that day).

    Returns:
        The rows inside the range. When no row matches, a warning is shown
        and all rows are returned. When filtering fails, an error is shown
        and ``df`` itself is returned unchanged.
    """
    try:
        timestamps = pd.to_datetime(df["Activity_Start_Date_Time"])

        window_start = pd.Timestamp(start_date).normalize()
        # Half-open upper bound: everything before the next midnight is in
        # range, whatever the timestamp resolution.
        window_end = pd.Timestamp(end_date).normalize() + pd.Timedelta(days=1)

        in_window = (timestamps >= window_start) & (timestamps < window_end)
        if not in_window.any():
            st.warning("No data found for the selected date range")
            return df.assign(Activity_Start_Date_Time=timestamps)

        return df.loc[in_window].assign(
            Activity_Start_Date_Time=timestamps[in_window]
        )
    except Exception as e:
        st.error(f"Error filtering data: {str(e)}")
        return df
def create_summaries(
    raw_df: pd.DataFrame,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Build the three summary tables derived from ``raw_df``.

    Returns:
        A tuple of the per-station-and-position summary, the overall
        summary, and the per-station summary passed through
        ``create_multiindex_columns``.
    """
    by_station = create_summary_by_station_and_position(raw_df)
    overall = create_overall_summary(raw_df)
    return by_station, overall, create_multiindex_columns(by_station)
def prepare_downloads(raw_df):
    """Build the download payloads for ``raw_df``.

    Returns:
        A dict whose ``"raw"`` entry maps the label ``"CSV"`` to a tuple of
        (CSV text without the index, file extension, MIME type).
    """
    csv_text = raw_df.to_csv(index=False)
    raw_formats = {"CSV": (csv_text, "csv", "text/csv")}
    return {"raw": raw_formats}
def add_lat_long(raw_df: pd.DataFrame, stations_df: pd.DataFrame) -> pd.DataFrame:
    """Add latitude and longitude to raw data based on station number.

    ``raw_df`` is left untouched: the float join key is added to a copy and
    removed again before returning.

    Args:
        raw_df: Frame with a ``Station_Number`` column convertible to float.
        stations_df: Frame with ``Number``, ``Latitude`` and ``Longitude``
            columns.

    Returns:
        A new frame with the rows of ``raw_df`` plus ``Latitude`` and
        ``Longitude``; rows whose station has no match get missing
        coordinates (left join).
    """
    keyed = raw_df.assign(Number=raw_df["Station_Number"].astype(float))
    merged = keyed.merge(
        stations_df[["Number", "Latitude", "Longitude"]],
        on="Number",
        how="left",
    )
    return merged.drop(columns="Number")
def get_stations_data() -> pd.DataFrame:
    """Return station locations with per-station sampling statistics.

    Reads the currently loaded frame from ``st.session_state.data["raw_df"]``
    and the station list from ``data/Stations-Locations.csv``.

    Returns:
        The station rows that have at least one matching group in the raw
        data, extended with ``Earliest_Sample`` and ``Most_Recent_Sample``
        (dates) and ``Total_Samples`` (integer count).
    """
    raw_df = st.session_state.data["raw_df"]

    # Earliest, latest and number of samples per station in one pass.
    sample_dates = (
        raw_df.groupby("Station_Number")["Activity_Start_Date_Time"]
        .agg(["min", "max", "count"])
        .reset_index()
        .rename(
            columns={
                "min": "Earliest_Sample",
                "max": "Most_Recent_Sample",
                "count": "Total_Samples",
            }
        )
        .astype({"Station_Number": float})
    )

    sampled_stations = (
        pd.read_csv("data/Stations-Locations.csv")
        .merge(sample_dates, left_on="Number", right_on="Station_Number", how="left")
        .drop("Station_Number", axis=1)
        .dropna(subset=["Total_Samples"])
    )

    # The left merge introduces NaN for stations without samples, which
    # upcasts the count to float; cast back once those rows are gone.
    return sampled_stations.assign(
        Most_Recent_Sample=pd.to_datetime(sampled_stations["Most_Recent_Sample"]).dt.date,
        Earliest_Sample=pd.to_datetime(sampled_stations["Earliest_Sample"]).dt.date,
        Total_Samples=sampled_stations["Total_Samples"].astype(int),
    )
def get_analyte_data_with_lat_long(df: pd.DataFrame, analyte: str) -> pd.DataFrame:
    """Extract and transform data for a specific analyte, adding coordinates.

    This function processes raw water quality data by:
    1. Filtering for a specific analyte
    2. Removing rows without a result value
    3. Adding latitude/longitude coordinates from stations data
    4. Aggregating duplicate measurements using mean values

    The analyte is matched by plain equality, so names containing quotes or
    other special characters are handled. ``df`` itself is not modified.

    Args:
        df: Raw water quality data containing at minimum these columns:
            - Station_Number
            - Org_Analyte_Name
            - Org_Result_Value
            - Reporting_Year
            plus the other columns used as pivot index below.
        analyte: Name of the analyte to filter for
            (e.g., "Temperature, Water").

    Returns:
        Processed dataframe with columns:
            - Activity_Start_Date_Time: Timestamp of measurement
            - Station_Number: Monitoring station identifier
            - Sector: Geographical sector
            - WBID: Waterbody ID
            - Sample_Position: Position of sample (e.g., "Surface", "Bottom")
            - Activity_Depth: Depth of measurement
            - Latitude: Station latitude
            - Longitude: Station longitude
            - Reporting_Year: Reporting year
            - {analyte}: Measured value for the specified analyte

    Note:
        Duplicate measurements at the same location and time are averaged.
    """
    # Filter first so only this analyte's rows are merged with the station
    # table; the explicit copy keeps the caller's frame out of the pipeline.
    analyte_rows = (
        df.loc[df["Org_Analyte_Name"] == analyte]
        .dropna(subset=["Org_Result_Value"])
        .copy()
    )

    return (
        analyte_rows.pipe(add_lat_long, get_stations_data())
        .pivot_table(
            index=[
                "Activity_Start_Date_Time",
                "Station_Number",
                "Sector",
                "WBID",
                "Sample_Position",
                "Activity_Depth",
                "Latitude",
                "Longitude",
                "Reporting_Year",
            ],
            values="Org_Result_Value",
            aggfunc="mean",
            observed=True,
        )
        .reset_index()
        .rename(columns={"Org_Result_Value": analyte})
    )
def load_seasonal_data(raw_df, analyte):
    """Prepare the data used by the seasonal trends analysis.

    Delegates to ``get_analyte_data_with_lat_long`` and returns its result
    unchanged.
    """
    seasonal_frame = get_analyte_data_with_lat_long(raw_df, analyte)
    return seasonal_frame