import logging
import sys
from pathlib import Path

import contextily as ctx
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from osgeo import gdal

from utils.data_loading import timer

logger = logging.getLogger(__name__)

BASEMAP_PROVIDERS = {
    "USGS Topo": ctx.providers.USGS.USTopo,  # type: ignore
    "OpenStreetMap": ctx.providers.OpenStreetMap.Mapnik,  # type: ignore
    "CartoDB Light": ctx.providers.CartoDB.Positron,  # type: ignore
    "CartoDB Voyager": ctx.providers.CartoDB.Voyager,  # type: ignore
    "NASAGIBS.ASTER_GDEM_Greyscale_Shaded_Relief": ctx.providers.NASAGIBS.ASTER_GDEM_Greyscale_Shaded_Relief,  # type: ignore
    "OpenTopoMap": ctx.providers.OpenTopoMap,  # type: ignore
}

# Shapefile holding the polygons for each supported (lower-cased) ``area_type``.
_AREA_SHAPEFILES = {
    "wbid": "data/waterbody_ids/Waterbody_IDs_(WBIDs).shp",
    "sector": "data/sab_sectors/SAB_Sectors.shp",
}

# Marker shapes cycled through to distinguish the stations of each area.
_STATION_MARKERS = ["o", "s", "^", "X", "*", "P", "<", "p", "h", "8"]

# Quarter labels in drawing order (row-major across the 2x2 grid).
_QUARTERS = ["Q1", "Q2", "Q3", "Q4"]


def _area_column(area_type: str) -> str:
    """Return the column naming an area: "Sector" for sector mode, else "WBID"."""
    return "Sector" if area_type.lower() == "sector" else "WBID"


@timer(include_params=True)
def generate_seasonal_plot(
    data: pd.DataFrame,
    parameter: str,
    year_range: list[int],
    areas: list[str],
    area_type: str = "wbid",
    reporting_end_month: int = 10,
    basemap_provider=ctx.providers.USGS.USTopo,  # type: ignore
    alpha: float = 0.5,
    show_marks: bool = True,
) -> tuple[Figure, pd.DataFrame, pd.DataFrame]:
    """
    Create seasonal plots of mean parameter values by WBID or Sector.

    Parameters
    ----------
    data : pd.DataFrame
        Measurements with at least ``Reporting_Year``, ``Activity_Start_Date_Time``
        (datetime), ``Org_Result_Value``, ``Latitude``, ``Longitude`` and the area
        column (``WBID`` or ``Sector``).
    parameter : str
        Parameter to plot (e.g., "Salinity", "Dissolved Oxygen"); used for labels,
        colors and as the name of the mean column.
    year_range : list[int]
        [start_year, end_year] (inclusive) for data filtering. If both are the same
        year, a single-year plot is produced.
    areas : list[str]
        List of WBIDs or Sector names to plot.
    area_type : str
        Either "wbid" or "sector" (case-insensitive) to specify how to filter the data.
    reporting_end_month : int
        Last month of the reporting year (1-12).
    basemap_provider : contextily provider
        Tile provider for the basemap.
    alpha : float
        Transparency of the basemap.
    show_marks : bool
        Whether to show station markers on the plot.

    Returns
    -------
    tuple[Figure, pd.DataFrame, pd.DataFrame]
        - Figure: Matplotlib figure containing the plot
        - DataFrame: Raw data used in the plot (with added ``quarter`` and
          ``month`` columns)
        - DataFrame: Quarterly means (area column, ``quarter``, ``parameter``)

    Raises
    ------
    ValueError
        If ``area_type`` is not "wbid"/"sector", or none of ``areas`` exists in the
        corresponding shapefile.
    """
    area_kind = area_type.lower()
    if area_kind not in _AREA_SHAPEFILES:
        raise ValueError(f"Invalid area_type: {area_type}")
    area_column = _area_column(area_kind)

    # Load the area polygons and reproject to Web Mercator to match basemap tiles.
    areas_gdf = gpd.read_file(_AREA_SHAPEFILES[area_kind])
    filtered_areas = areas_gdf[areas_gdf[area_column].isin(areas)].to_crs("EPSG:3857")
    if filtered_areas.empty:
        # Without any polygon the map extent would be NaN and plotting would fail later.
        raise ValueError(f"None of the requested areas were found in the {area_column} shapefile: {areas}")

    # Filter data for year range and areas.
    first_year, last_year = year_range[0], year_range[1]
    year_data = data[
        data["Reporting_Year"].between(first_year, last_year) & data[area_column].isin(areas)
    ].copy()

    # Quarter labels are needed on the station points, which are built before the means.
    year_data["quarter"] = year_data["Activity_Start_Date_Time"].apply(
        lambda ts: get_quarter(ts, reporting_end_month)
    )

    # One marker shape per requested area (cycled if there are more areas than shapes).
    sector_markers = {
        area: _STATION_MARKERS[i % len(_STATION_MARKERS)] for i, area in enumerate(areas)
    }

    # Station points in Web Mercator (input coordinates are assumed to be WGS84 lon/lat).
    stations = None
    if show_marks:
        stations = gpd.GeoDataFrame(  # type: ignore
            year_data,
            geometry=gpd.points_from_xy(year_data.Longitude, year_data.Latitude),
            crs="EPSG:4326",
        ).to_crs("EPSG:3857")  # type: ignore

    seasonal_means = calculate_quarterly_means(
        year_data, parameter, reporting_end_month, area_kind
    )

    fig = create_quarterly_maps(  # type: ignore
        seasonal_means=seasonal_means,
        areas_gdf=filtered_areas,
        parameter=parameter,
        year_range=year_range,
        area_type=area_kind,
        reporting_end_month=reporting_end_month,
        basemap_provider=basemap_provider,
        alpha=alpha,
        stations=stations,
        sector_markers=sector_markers if show_marks else None,
    )

    return fig, year_data, seasonal_means[[area_column, "quarter", parameter]]


def calculate_quarterly_means(
    data: pd.DataFrame,
    parameter: str,
    reporting_end_month: int,
    area_type: str = "wbid",
) -> pd.DataFrame:
    """Calculate the mean ``Org_Result_Value`` per area and reporting quarter.

    Note: ``data`` is modified in place — ``quarter`` and ``month`` columns are
    added (``generate_seasonal_plot`` returns that frame as its raw data).

    Returns a frame with the area column, ``quarter``, the mean value under the
    column name ``parameter``, and ``months_sampled`` (number of distinct calendar
    months with data in that area/quarter).
    """
    data["quarter"] = data["Activity_Start_Date_Time"].apply(
        lambda ts: get_quarter(ts, reporting_end_month)
    )
    # Month of each sample, used to count how many months of a quarter were sampled.
    data["month"] = data["Activity_Start_Date_Time"].dt.month

    area_column = _area_column(area_type)
    return (
        data.groupby([area_column, "quarter"], observed=True)
        .agg({"Org_Result_Value": "mean", "month": "nunique"})
        .reset_index()
        .rename(columns={"Org_Result_Value": parameter, "month": "months_sampled"})
    )


def get_quarter(date, reporting_end_month: int) -> str:
    """Return "Q1".."Q4" for ``date`` in a reporting year ending in ``reporting_end_month``.

    Q1 starts in the month after ``reporting_end_month``; e.g. with an October
    year end, November-January is Q1 and August-October is Q4:
    ``get_quarter(pd.Timestamp("2023-11-15"), 10) == "Q1"``.
    """
    month = date.month
    # Shift months so the first month of the reporting year becomes 1.
    month_offset = (12 - reporting_end_month) % 12
    adjusted_month = ((month + month_offset) % 12) or 12
    return f"Q{((adjusted_month - 1) // 3) + 1}"


def create_quarterly_maps(
    seasonal_means: pd.DataFrame,
    areas_gdf: gpd.GeoDataFrame,
    parameter: str,
    year_range: list[int],
    area_type: str,
    reporting_end_month: int,
    basemap_provider,
    alpha: float = 0.5,
    stations: gpd.GeoDataFrame | pd.DataFrame | None = None,
    sector_markers: dict | None = None,
) -> Figure:
    """Build a 2x2 figure with one map per reporting quarter plus a shared colorbar.

    A single figure-level legend of station markers is added when both
    ``stations`` and ``sector_markers`` are given.
    """
    fig = plt.figure(figsize=(20, 14))
    # Tight grid spacing (negative hspace) to reduce gaps between the map rows.
    gs = fig.add_gridspec(
        2,
        2,
        width_ratios=[1, 1],
        wspace=0.05,
        hspace=-0.15,
        left=0.02,
        right=0.92,
        top=0.95,
        bottom=0.05,
    )

    cmap = LinearSegmentedColormap.from_list("custom", get_parameter_colors(parameter), N=100)
    # All four maps share the same extent so they can be compared visually.
    extent = calculate_map_extent(areas_gdf.total_bounds)

    if year_range[0] == year_range[1]:
        title = f"Seasonal {parameter} Values for {year_range[0]}"
    else:
        title = f"Seasonal {parameter} Values ({year_range[0]}-{year_range[1]})"
    fig.suptitle(title, fontsize=14, y=0.95)

    for idx, quarter in enumerate(_QUARTERS):
        ax = fig.add_subplot(gs[idx // 2, idx % 2])
        plot_quarter(
            ax=ax,
            quarter=quarter,
            seasonal_means=seasonal_means,
            areas_gdf=areas_gdf,
            parameter=parameter,
            year_range=year_range,
            area_type=area_type,
            reporting_end_month=reporting_end_month,
            cmap=cmap,
            extent=extent,
            basemap_provider=basemap_provider,
            alpha=alpha,
            stations=stations,
            sector_markers=sector_markers,
            add_legend=False,  # one shared legend is added to the figure below
        )

    if stations is not None and sector_markers is not None:
        # Proxy artists for the legend; unlike plt.scatter they are not drawn into any axes.
        legend_handles = [
            Line2D(
                [],
                [],
                marker=marker,
                linestyle="None",
                color="black",
                markersize=5,  # matches scatter s=25 (points**2) used on the maps
                alpha=0.5,
                label=area,
            )
            for area, marker in sector_markers.items()
        ]
        fig.legend(
            handles=legend_handles,
            bbox_to_anchor=(0.90, 0.87),
            loc="upper left",
            borderaxespad=0.0,
            title="Station Locations",
        )

    add_colorbar(fig, seasonal_means, parameter, cmap)
    return fig


def plot_quarter(
    ax: plt.Axes,  # type: ignore
    quarter: str,
    seasonal_means: pd.DataFrame,
    areas_gdf: gpd.GeoDataFrame,
    parameter: str,
    year_range: list[int],
    area_type: str,
    reporting_end_month: int,
    cmap: LinearSegmentedColormap,
    extent: list[float],
    basemap_provider,
    alpha: float = 0.5,
    stations: gpd.GeoDataFrame | pd.DataFrame | None = None,
    sector_markers: dict | None = None,
    add_legend: bool = False,
) -> None:
    """Draw one quarter's choropleth, basemap, title and station markers onto ``ax``.

    Parameters
    ----------
    quarter : str
        "Q1".."Q4".
    seasonal_means : pd.DataFrame
        Output of ``calculate_quarterly_means``.
    areas_gdf : gpd.GeoDataFrame
        Area polygons in EPSG:3857; not modified.
    extent : list[float]
        [xmin, xmax, ymin, ymax] applied to the axes.
    stations, sector_markers
        When both are given, the quarter's unique stations are drawn with the
        marker assigned to their area.
    add_legend : bool
        Accepted for interface compatibility; currently unused (the shared legend
        is drawn by ``create_quarterly_maps``).
    """
    area_column = _area_column(area_type)

    quarter_data = seasonal_means[seasonal_means["quarter"] == quarter]
    quarter_means = (
        quarter_data.groupby(area_column, observed=True)
        .agg({parameter: ["mean", "min", "max", "count"]})
        .reset_index()
    )
    quarter_means.columns = [
        area_column,
        f"{parameter}_mean",
        f"{parameter}_min",
        f"{parameter}_max",
        "count",
    ]
    logger.debug("Summary statistics per %s for %s:\n%s", area_column, quarter, quarter_means)

    plot_data = quarter_means.rename(columns={f"{parameter}_mean": parameter})[
        [area_column, parameter]
    ]
    merged = _join_values_to_areas(areas_gdf, plot_data, area_column)

    # Shared limits keep the four maps and the colorbar on one color scale.
    vmin, vmax = _color_limits(seasonal_means, parameter)
    logger.debug(
        "Value range: %s to %s; quarter data range: %s to %s",
        vmin,
        vmax,
        merged[parameter].min(),
        merged[parameter].max(),
    )

    merged.plot(
        column=parameter,
        ax=ax,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        alpha=0.7,
        missing_kwds={"color": "lightgrey", "alpha": 0.5},
        legend=False,
    )

    _add_basemap_with_fallback(ax, basemap_provider, alpha)

    ax.set_xlim(extent[0], extent[1])
    ax.set_ylim(extent[2], extent[3])

    # Top-row maps get more title padding than bottom-row maps.
    title_pad = 15 if int(quarter[1]) <= 2 else 5
    ax.set_title(
        _quarter_title(quarter, parameter, year_range, reporting_end_month),
        pad=title_pad,
        fontsize=10,
    )
    ax.set_axis_off()

    if stations is not None and sector_markers is not None:
        _plot_station_markers(ax, stations, quarter, area_column, sector_markers)


def _dissolve_areas(areas_gdf: gpd.GeoDataFrame, area_column: str) -> gpd.GeoDataFrame:
    """Return a copy of ``areas_gdf`` reduced to one (multi)polygon row per area."""
    areas = areas_gdf.copy()
    try:
        # buffer(0) repairs invalid geometries that would otherwise break dissolve.
        areas["geometry"] = areas["geometry"].buffer(0)  # type: ignore
        return areas.dissolve(by=area_column).reset_index()  # type: ignore
    except Exception as exc:  # best-effort: fall back rather than abort the plot
        logger.warning("Could not dissolve geometries: %s", exc)
        # Keep the first polygon of each area; drop_duplicates keeps the GeoDataFrame type.
        return areas.drop_duplicates(subset=[area_column])  # type: ignore


def _join_values_to_areas(
    areas_gdf: gpd.GeoDataFrame, plot_data: pd.DataFrame, area_column: str
) -> gpd.GeoDataFrame:
    """Left-join per-area values onto the dissolved polygons; areas without data get NaN."""
    merged = _dissolve_areas(areas_gdf, area_column).merge(plot_data, on=area_column, how="left")
    logger.debug("Shape of merged data: %s", merged.shape)
    duplicated = merged.duplicated(subset=[area_column], keep=False)
    if duplicated.any():
        logger.warning(
            "Found duplicate %s rows after merge:\n%s",
            area_column,
            merged[duplicated].sort_values(area_column),
        )
    return merged


def _add_basemap_with_fallback(ax, basemap_provider, alpha: float) -> None:
    """Add the requested basemap, falling back to CartoDB Voyager; report failures in the UI."""
    try:
        ctx.add_basemap(ax, source=basemap_provider, zoom=11, alpha=alpha)  # type: ignore
    except Exception as e:
        st.warning(f"Primary basemap failed ({str(e)}), using fallback provider")
        try:
            ctx.add_basemap(
                ax,
                source=ctx.providers.CartoDB.Voyager,  # type: ignore
                zoom=11,  # type: ignore
                alpha=alpha,
            )
        except Exception as e2:
            st.error(f"Fallback basemap also failed: {str(e2)}")


def _quarter_title(
    quarter: str, parameter: str, year_range: list[int], reporting_end_month: int
) -> str:
    """Build the two-line subplot title: quarter/parameter, then the covered date span."""
    first_year, last_year = year_range[0], year_range[1]
    if first_year == last_year:
        date_range = get_quarter_dates(quarter, first_year, reporting_end_month)
    else:
        # Span from the quarter's start in the first year to its end in the last year.
        start_date = get_quarter_dates(quarter, first_year, reporting_end_month).split(" - ")[0]
        end_date = get_quarter_dates(quarter, last_year, reporting_end_month).split(" - ")[1]
        date_range = f"{start_date} - {end_date}"
    return f"Quarter {quarter[1]} Mean {parameter}\n{date_range}"


def _plot_station_markers(
    ax, stations, quarter: str, area_column: str, sector_markers: dict
) -> None:
    """Scatter each area's unique stations sampled in ``quarter`` with that area's marker."""
    quarter_stations = stations[stations["quarter"] == quarter]
    for area, marker in sector_markers.items():
        area_stations = quarter_stations[quarter_stations[area_column] == area]
        # Prefer the station ID; otherwise treat identical coordinates as one station.
        if "Station_Number" in area_stations.columns:
            dedupe_cols = ["Station_Number"]
        else:
            dedupe_cols = ["Latitude", "Longitude"]
        unique_stations = area_stations.drop_duplicates(subset=dedupe_cols)  # type: ignore
        ax.scatter(
            unique_stations.geometry.x,
            unique_stations.geometry.y,
            marker=marker,
            color="black",
            s=25,
            alpha=0.5,
        )


def get_parameter_max_value(parameter: str, data_max: float) -> float:
    """Return the colormap upper limit for ``parameter``.

    Parameters with a fixed limit use it; all others (including unknown
    parameters) use ``data_max``.
    """
    parameter_limits: dict[str, float | None] = {
        "Salinity": 40,
        "Dissolved Oxygen": 12,
        "pH": 9,
        "Temperature, Water": 35,
        "Turbidity": None,  # Use data max
        "Total Nitrogen": None,
        "Total Phosphorus": None,
        "Fecal Coliform (MPN)": None,
    }
    # .get(key, default) would return the stored None for the keys above, so test explicitly.
    limit = parameter_limits.get(parameter)
    return data_max if limit is None else limit


def _color_limits(seasonal_means: pd.DataFrame, parameter: str) -> tuple[float, float]:
    """Return the (vmin, vmax) shared by every quarter map and the colorbar.

    vmin is always 0. vmax falls back to 1.0 when it is not a positive finite
    number (e.g. no data), so the normalisation never degenerates.
    """
    vmax = get_parameter_max_value(parameter, seasonal_means[parameter].max())
    if vmax is None or not np.isfinite(vmax) or vmax <= 0:
        vmax = 1.0
    return 0.0, float(vmax)


def calculate_map_extent(
    bounds: np.ndarray, buffer_fraction: float = 0.03
) -> list[float]:
    """Pad ``bounds`` ([minx, miny, maxx, maxy]) by ``buffer_fraction`` of each span.

    Returns [xmin, xmax, ymin, ymax], the order used by set_xlim/set_ylim.
    """
    min_x, min_y, max_x, max_y = bounds[0], bounds[1], bounds[2], bounds[3]
    x_buffer = (max_x - min_x) * buffer_fraction
    y_buffer = (max_y - min_y) * buffer_fraction
    return [
        min_x - x_buffer,  # xmin
        max_x + x_buffer,  # xmax
        min_y - y_buffer,  # ymin
        max_y + y_buffer,  # ymax
    ]


def get_quarter_dates(quarter: str, year: int, reporting_end_month: int) -> str:
    """Return the date span of ``quarter`` in the reporting year ending in ``year``.

    Example: ``get_quarter_dates("Q1", 2024, 10)`` -> "Nov 01, 2023 - Jan 31, 2024".
    Uses the same quarter convention as ``get_quarter``.
    """
    # First month of the reporting year is the month after its last month.
    first_month = (reporting_end_month % 12) + 1

    quarter_num = int(quarter[1])
    start_month = ((first_month - 1 + ((quarter_num - 1) * 3)) % 12) + 1
    end_month = ((start_month + 2) % 12) or 12

    # Months after the reporting year's end month belong to the previous calendar year.
    start_year = year - 1 if start_month > reporting_end_month else year
    # The quarter crosses New Year when its end month wraps below its start month.
    end_year = start_year if end_month >= start_month else start_year + 1

    start_date = pd.Timestamp(year=start_year, month=start_month, day=1)
    end_month_first = pd.Timestamp(year=end_year, month=end_month, day=1)
    end_date = end_month_first.replace(day=end_month_first.days_in_month)

    return f"{start_date.strftime('%b %d, %Y')} - {end_date.strftime('%b %d, %Y')}"


def _colorbar_ticks(vmax: float) -> np.ndarray:
    """Round-number ticks starting at 0; ticks above ``vmax`` are not shown by the colorbar."""
    if vmax <= 1:
        return np.array([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    if vmax <= 10:
        return np.array([0, 2, 4, 6, 8, 10])
    if vmax <= 50:
        return np.array([0, 10, 20, 30, 40, 50])
    # Larger (data-driven) ranges: six ticks up to the next multiple of 100.
    return np.linspace(0, np.ceil(vmax / 100) * 100, 6)


def add_colorbar(
    fig: Figure,
    seasonal_means: pd.DataFrame,
    parameter: str,
    cmap: LinearSegmentedColormap,
) -> None:
    """Add one vertical colorbar spanning all axes of ``fig``.

    Uses the same (vmin, vmax) as the quarter maps so colors match the scale.
    """
    vmin, vmax = _color_limits(seasonal_means, parameter)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)  # type: ignore
    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])

    unit = get_parameter_unit(parameter)
    label = f"{parameter} ({unit})" if unit else parameter

    fig.colorbar(
        mappable,
        ax=fig.axes,
        orientation="vertical",
        label=label,
        pad=0.02,
        fraction=0.015,
        ticks=_colorbar_ticks(vmax),
    )


def get_parameter_unit(parameter: str) -> str:
    """Return the display unit for ``parameter`` ("" if unitless or unknown)."""
    parameter_units = {
        "Salinity": "ppt",
        "Dissolved Oxygen": "mg/L",
        "pH": "",
        "Temperature, Water": "°C",
        "Turbidity": "NTU",
        "Total Nitrogen": "mg/L",
        "Total Phosphorus": "mg/L",
        "Fecal Coliform (MPN)": "MPN/100mL",
    }
    return parameter_units.get(parameter, "")


def get_parameter_colors(parameter: str) -> list[str]:
    """Return the colormap anchor colors for ``parameter``, ordered from low to high values.

    Parameters where high values are concerning (nutrients, bacteria, turbidity,
    temperature, salinity) run blue -> red. Dissolved oxygen, where low values
    are concerning, runs red -> blue. pH is red at both extremes.
    """
    # Blue (low) -> red (high): for parameters where higher values are concerning.
    default_colors = ["#08519c", "#73a9cf", "#fee090", "#fc8d59", "#d73027"]

    parameter_colors = {
        # Temperature: cold (blue) to hot (red)
        "Temperature, Water": default_colors,
        # DO: low (red) to high (blue) - reversed default scheme
        "Dissolved Oxygen": default_colors[::-1],
        # pH: low (red) to neutral (yellow) to high (red)
        "pH": ["#d73027", "#fc8d59", "#fee090", "#fc8d59", "#d73027"],
        # Nutrients: low (blue) to high (red)
        "Total Nitrogen": default_colors,
        "Total Phosphorus": default_colors,
        # Turbidity: clear (blue) to turbid (red)
        "Turbidity": default_colors,
        # Bacteria: low (blue) to high (red)
        "Fecal Coliform (MPN)": default_colors,
        # Salinity: fresh (blue) to saline (red)
        "Salinity": default_colors,
    }
    return parameter_colors.get(parameter, default_colors)


def debugging_info(data: pd.DataFrame, shapefile_path: str) -> None:
    """Show CRS, library-version and file diagnostics for a shapefile in the Streamlit UI.

    ``data`` is not modified; a missing CRS on a GeoDataFrame is assumed to be
    WGS84 (EPSG:4326) for display only.
    """
    sectors_gdf = gpd.read_file(shapefile_path)

    if isinstance(data, gpd.GeoDataFrame) and data.crs is None:
        # Assuming the input coordinates are in WGS84 (EPSG:4326); local copy only.
        data = data.set_crs(epsg=4326)

    if sectors_gdf.crs is None:
        # NOTE(review): EPSG:6439 is assumed for shapefiles lacking a .prj — confirm.
        sectors_gdf = sectors_gdf.set_crs(epsg=6439)

    # Report the CRS the plotting code works in (Web Mercator).
    sectors_gdf = sectors_gdf.to_crs(epsg=3857)

    st.write("Debug Info:")
    st.write(
        {
            "Shapefile CRS": sectors_gdf.crs,
            "Input Data CRS": data.crs
            if isinstance(data, gpd.GeoDataFrame)
            else "Not a GeoDataFrame",
            "GDAL Version": gdal.VersionInfo()
            if "osgeo.gdal" in sys.modules
            else "Not available",
            "GeoPandas Version": gpd.__version__,
            "Python Version": sys.version,
            "File exists": Path(shapefile_path).exists(),
            "Associated files": list(Path(shapefile_path).parent.glob("*.*")),
        }
    )