# waterdb/plots/base.py
# Deployed from GitHub Actions by github-actions[bot] (commit db5d970).
import math
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
# Fixed palette of 24 hex colour strings, listed four per row in the
# original order. Kept as a list so indexing/slicing callers are unaffected.
COLOR_SCALE = [
    "#6D3E91", "#C05917", "#58AC8C", "#286BBB",
    "#883039", "#BC8E5A", "#00295B", "#C15065",
    "#18470F", "#9A5129", "#E56E5A", "#A2559C",
    "#38AABA", "#578145", "#970046", "#00847E",
    "#B13507", "#4C6A9C", "#CF0A66", "#00875E",
    "#B16214", "#8C4569", "#3B8E1D", "#D73C50",
]
def plot_calendar_heatmap(
    df: pd.DataFrame,
    analyte: str,
    colormap: str | None = None,
    position_filter: str = "All",
) -> Figure:
    """
    Create a year-by-month heatmap of monthly mean values for one analyte.

    Parameters:
    -----------
    df : pd.DataFrame
        Measurements with "Org_Analyte_Name", "Org_Result_Value",
        "Org_Result_Unit" and a datetime-typed "Activity_Start_Date_Time".
    analyte : str
        Value of "Org_Analyte_Name" to plot.
    colormap : str | None
        Matplotlib colormap name; when truthy it overrides the
        analyte-specific default.
    position_filter : str
        Used only for the title and the error message (this function does
        not filter by sample position itself). "All" is displayed as
        "surface and bottom".

    Returns:
    --------
    Figure
        Matplotlib figure containing the heatmap.

    Raises:
    -------
    ValueError
        If ``df`` has no rows, or no aggregable values, for ``analyte``.
    """
    data = df[df["Org_Analyte_Name"] == analyte].copy()
    if data.empty:
        raise ValueError(
            f"No data available for {analyte} with position filter: {position_filter}"
        )
    # Rows are guaranteed to exist here, so the first unit is always present.
    result_unit = data["Org_Result_Unit"].iloc[0]

    data["Year"] = data["Activity_Start_Date_Time"].dt.year
    data["Month"] = data["Activity_Start_Date_Time"].dt.month
    pivot_data = data.pivot_table(
        values="Org_Result_Value", index="Year", columns="Month", aggfunc="mean"
    )
    if pivot_data.empty:
        # e.g. every Org_Result_Value is NaN; seaborn would fail obscurely.
        raise ValueError(
            f"No data available for {analyte} with position filter: {position_filter}"
        )

    # Analyte-specific default colormaps; anything else falls back to "Blues".
    default_cmaps = {
        "Fecal Coliform (MPN)": "viridis",  # blue-green-yellow
        "Temperature, Water": "coolwarm",
        "Dissolved Oxygen": "RdYlBu",
        "Total Nitrogen": "GnBu",  # green-blue
        "Total Phosphorus": "GnBu",
        "Depth, Secchi Disk Depth": "Blues_r",
    }
    cmap = colormap if colormap else default_cmaps.get(analyte, "Blues")

    # Height scales with the number of years, with a floor so a single year
    # still leaves room for the title, labels and colorbar.
    fig, ax = plt.subplots(figsize=(6, max(len(pivot_data) * 0.5, 2.0)))
    sns.heatmap(
        pivot_data,
        ax=ax,  # draw on our axes explicitly rather than pyplot's "current" axes
        cmap=cmap,
        annot=True,
        fmt=".2f",
        cbar_kws={"label": result_unit},
        annot_kws={"size": 6},
    )

    position_label = (
        "Surface and Bottom" if position_filter == "All" else position_filter
    )
    ax.set_title(
        f"Monthly Averages: {analyte} ({position_label.lower()})", fontsize=10, pad=10
    )
    ax.tick_params(axis="both", which="major", labelsize=7)
    ax.set_xlabel("Month", fontsize=6)
    ax.set_ylabel("Year", fontsize=6)

    # The heatmap's QuadMesh is the first collection; shrink its colorbar text.
    colorbar = ax.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=7)  # type: ignore
    colorbar.set_label(result_unit, size=7)  # type: ignore
    return fig
def plot_do_timeseries(
    df: pd.DataFrame,
    period: str = "Yearly",
    sector: str = "All",
    epa_thresh: float = 4.8,
) -> Figure:
    """
    Create a time series plot of dissolved oxygen levels for surface and bottom measurements.

    Each period shows the mean surface and mean bottom value as two points
    joined by a light vertical connector, plus a dashed EPA threshold line.

    Reference:
    https://www.hudsonriver.org/ccmp/soe/water-quality/do

    Parameters:
    -----------
    df : pd.DataFrame
        Filtered dataframe containing dissolved oxygen measurements
    period : str
        'yearly' or 'monthly' aggregation period (case-insensitive; any
        value other than 'yearly' is treated as monthly)
    sector : str
        Currently unused; accepted for signature compatibility.
    epa_thresh : float
        EPA threshold value for DO in mg/L

    Returns:
    --------
    Figure
        Matplotlib figure containing the plot

    Raises:
    -------
    ValueError
        If ``df`` contains no surface or bottom dissolved oxygen rows.
    """
    period = period.lower()

    # Keep only surface/bottom dissolved-oxygen samples.
    do_data = df[
        (df["Org_Analyte_Name"] == "Dissolved Oxygen")
        & (df["Sample_Position"].isin(["Surface", "Bottom"]))
    ].copy()
    if do_data.empty:
        raise ValueError("No surface or bottom Dissolved Oxygen data to plot")

    # Grouping key: reporting year, or calendar month.
    if period == "yearly":
        do_data["Period"] = do_data["Reporting_Year"]
    else:
        do_data["Period"] = pd.to_datetime(
            do_data["Activity_Start_Date_Time"]
        ).dt.to_period("M")

    # One row per period, one column per position. reindex guarantees both
    # columns exist even when one position has no samples at all, so the
    # lookups below never raise KeyError (missing values become NaN).
    means = (
        do_data.groupby(["Period", "Sample_Position"], observed=True)[
            "Org_Result_Value"
        ]
        .mean()
        .reset_index()
        .pivot(index="Period", columns="Sample_Position", values="Org_Result_Value")
        .reindex(columns=["Surface", "Bottom"])
    )

    fig, ax = plt.subplots(figsize=(15, 8))

    # x positions: numeric years for yearly data, datetime64 for monthly.
    if period == "yearly":
        x_values = np.array(means.index.astype(float))
    else:
        x_values = np.array(
            [pd.Period(idx).to_timestamp() for idx in means.index],
            dtype="datetime64[ns]",
        )

    # Light connector between the bottom and surface mean of each period
    # (NaN endpoints simply produce no visible segment).
    for x_val, bottom, surface in zip(x_values, means["Bottom"], means["Surface"]):
        ax.plot(
            [x_val, x_val],
            [bottom, surface],
            color="lightgray",
            linewidth=1,
            zorder=1,
            solid_capstyle="round",
        )

    # Marker size decays exponentially with point count, floored at min_size.
    base_size = 80
    min_size = 20
    point_size = max(min_size, base_size * math.exp(-0.0015 * len(x_values)))

    surface_scatter = ax.scatter(
        x_values,
        means["Surface"],
        color="#5AA4D8",
        s=point_size,
        zorder=2,
        label="Surface",
        edgecolors="white",
        linewidth=1,
        alpha=0.9,
    )
    bottom_scatter = ax.scatter(
        x_values,
        means["Bottom"],
        color="#1B4B8A",
        s=point_size,
        zorder=2,
        label="Bottom",
        edgecolors="white",
        linewidth=1,
        alpha=0.9,
    )
    threshold_line = ax.axhline(
        y=epa_thresh,
        color="#FF8C00",
        linestyle="--",
        alpha=0.9,
        linewidth=1,
        label=f"EPA threshold: {epa_thresh} mg/L",
        zorder=0,
    )

    ax.legend(
        handles=[surface_scatter, bottom_scatter, threshold_line],
        loc="upper right",
        frameon=False,
        ncol=1,
        bbox_to_anchor=(1.0, 1.0),
        handletextpad=0.5,
    )

    # Only the bottom spine is shown.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_color("black")
    ax.spines["bottom"].set_linewidth(0.5)

    ax.set_xlabel("Year" if period == "yearly" else "Month")
    ax.set_ylabel("Dissolved Oxygen (mg/L)")
    ax.set_title("Long-term Dissolved Oxygen Trends")
    ax.grid(True, axis="y", alpha=0.15, linestyle="-", color="gray")

    # Y-limits span BOTH positions (NaN-skipping) and the threshold, so no
    # point or the threshold line is clipped and an all-NaN column can't crash.
    lowest = np.nanmin([means.min().min(), epa_thresh])
    highest = np.nanmax([means.max().max(), epa_thresh])
    ymin = max(int(lowest * 0.9) - 1, 0)
    ymax = highest * 1.1
    ax.set_ylim(ymin, ymax)
    ax.set_yticks(np.arange(ymin, ymax, 2))
    # Hide y tick marks but keep their labels.
    ax.tick_params(axis="y", which="both", length=0)

    if period == "monthly":
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
        ax.xaxis.set_major_locator(mdates.YearLocator())
        ax.tick_params(axis="x", labelrotation=0)
        # Pad the x-range by one month on each side.
        one_month = pd.DateOffset(months=1)
        ax.set_xlim(
            mdates.date2num(pd.Timestamp(x_values.min()) - one_month),
            mdates.date2num(pd.Timestamp(x_values.max()) + one_month),
        )
    else:
        # Whole-year ticks with roughly one month (1/12 year) of padding.
        min_year = float(np.floor(x_values.min()))
        max_year = float(np.ceil(x_values.max()))
        ax.set_xticks(np.arange(min_year, max_year + 1))
        ax.set_xlim(min_year - 0.083, max_year + 0.083)

    ax.yaxis.tick_left()
    ax.yaxis.set_label_position("left")
    fig.tight_layout()
    return fig
def plot_do_scatter(
    df: pd.DataFrame,
    sector: str = "All",
    thresh: float = 3.0,
) -> Figure:
    """
    Create a scatter plot of all dissolved oxygen measurements.

    Surface and bottom samples are plotted as separate series, with a dotted
    threshold line and, when the data spans 2018-10-10, a Hurricane Michael
    marker.

    Parameters:
    -----------
    df : pd.DataFrame
        Filtered dataframe containing dissolved oxygen measurements
    sector : str
        Sector name appended to the title, or 'All' for no suffix
    thresh : float
        Threshold value for DO in mg/L

    Returns:
    --------
    Figure
        Matplotlib figure containing the plot

    Raises:
    -------
    ValueError
        If ``df`` contains no surface or bottom dissolved oxygen rows.
    """
    do_data = df[
        (df["Org_Analyte_Name"] == "Dissolved Oxygen")
        & (df["Sample_Position"].isin(["Surface", "Bottom"]))
    ].copy()
    if do_data.empty:
        raise ValueError("No surface or bottom Dissolved Oxygen data to plot")

    fig, ax = plt.subplots(figsize=(15, 8))

    surface_data = do_data[do_data["Sample_Position"] == "Surface"]
    bottom_data = do_data[do_data["Sample_Position"] == "Bottom"]
    ax.scatter(
        surface_data["Activity_Start_Date_Time"],
        surface_data["Org_Result_Value"],
        color="#1f77b4",  # darker blue for surface
        s=25,
        alpha=0.5,
        label="Surface",
        zorder=2,
    )
    ax.scatter(
        bottom_data["Activity_Start_Date_Time"],
        bottom_data["Org_Result_Value"],
        color="#7fbf7b",  # muted green for bottom
        s=25,
        alpha=0.5,
        label="Bottom",
        zorder=2,
    )

    # Set the final y-range BEFORE drawing the hurricane marker, so the
    # marker is positioned against the limits that are actually displayed.
    values = do_data["Org_Result_Value"]
    ymin = max(int(min(values.min(), thresh) * 0.9) - 1, 0)
    ymax = values.max() * 1.1
    ax.set_ylim(ymin, ymax)
    ax.set_yticks(np.arange(ymin, ymax, 2))
    # Hide y tick marks but keep their labels.
    ax.tick_params(axis="y", which="both", length=0)

    # Hurricane Michael marker, only if its date falls inside the data range.
    hurricane_date = pd.Timestamp("2018-10-10")
    dates = do_data["Activity_Start_Date_Time"]
    if dates.min() <= hurricane_date <= dates.max():
        y_low, y_high = ax.get_ylim()
        line_height = y_high * 0.95
        # Draw in data coordinates so the line top coincides exactly with the
        # dot at line_height, whatever the lower y-limit is.
        ax.vlines(
            hurricane_date,  # type: ignore
            y_low,
            line_height,
            colors="gray",
            linestyles="solid",
            alpha=0.6,
            linewidth=1,
            zorder=1,
        )
        ax.scatter(
            [hurricane_date],  # type: ignore
            [line_height],
            color="gray",
            s=25,
            alpha=0.8,
            zorder=2,
        )
        # Two-line annotation: bold date, then the event name beneath it.
        ax.annotate(
            "Oct 2018",
            xy=(hurricane_date, line_height),  # type: ignore
            xytext=(5, 0),
            textcoords="offset points",
            ha="left",
            va="bottom",
            color="gray",
            fontsize=10,
            weight="bold",
        )
        ax.annotate(
            "Hurricane Michael",
            xy=(hurricane_date, line_height),  # type: ignore
            xytext=(5, -12),
            textcoords="offset points",
            ha="left",
            va="bottom",
            color="gray",
            fontsize=10,
        )

    ax.axhline(
        y=thresh,
        color="red",
        linestyle=":",
        alpha=0.9,
        linewidth=1.5,
        label=f"Threshold: {thresh} mg/L",
        zorder=1,
    )

    ax.legend(
        loc="upper right",
        frameon=True,
        ncol=1,
        bbox_to_anchor=(1.0, 1.0),
        handletextpad=0.5,
        fontsize=12,
    )

    # Only the bottom spine is shown.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_color("black")
    ax.spines["bottom"].set_linewidth(0.5)

    title = "DO mg/L"
    if sector != "All":
        title += f" - {sector}"
    ax.set_title(title, fontsize=14)
    ax.grid(True, axis="both", alpha=0.15, linestyle="-", color="gray")

    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
    fig.tight_layout()
    return fig
def plot_scatter(
    df: pd.DataFrame,
    parameter: str,
    year_range: tuple[int, int],
    sector: str = "All",
    thresh: float | None = None,
) -> tuple[Figure, pd.DataFrame]:
    """
    Create a scatter plot of water quality measurements for any parameter.

    Parameters:
    -----------
    df : pd.DataFrame
        Filtered dataframe containing water quality measurements
    parameter : str
        Name of the parameter to plot (e.g., "Dissolved Oxygen", "Temperature, Water")
    year_range : tuple[int, int]
        (start_year, end_year); the span selects how dense the x-axis
        month ticks are.
    sector : str
        Sector name appended to the title, or 'All' for no suffix
    thresh : float | None
        Optional threshold value to display on plot. The y-range is widened
        so the line is visible (non-positive values are ignored for the
        limits of log-scaled parameters).

    Returns:
    --------
    tuple[Figure, pd.DataFrame]
        - Figure: Matplotlib figure containing the scatter plot
        - DataFrame: Filtered dataframe containing the parameter data used in the plot

    Raises:
    -------
    ValueError
        If there are no surface/bottom rows for ``parameter``.
    """
    param_data = df[
        (df["Org_Analyte_Name"] == parameter)
        & (df["Sample_Position"].isin(["Surface", "Bottom"]))
    ].copy()
    if param_data.empty:
        raise ValueError(f"No data found for parameter: {parameter}")

    unit = param_data["Org_Result_Unit"].iloc[0]
    values = param_data["Org_Result_Value"]

    fig, ax = plt.subplots(figsize=(15, 8))
    surface_data = param_data[param_data["Sample_Position"] == "Surface"]
    bottom_data = param_data[param_data["Sample_Position"] == "Bottom"]

    log_scale_parameters = [
        "Turbidity",
        "Fecal Coliform (MPN)",
        "Total Nitrogen",
        "Total Phosphorus",
        "Color",
    ]
    log_scale = parameter in log_scale_parameters

    if log_scale:
        ax.set_yscale("log")
        # Limits roughly a factor of 2 beyond the data, never below 0.1.
        ymin = max(values.min() * 0.5, 0.1)
        ymax = values.max() * 2
        # A non-positive threshold cannot appear on a log axis; using it
        # would produce an invalid (<= 0) limit and a NaN log10 below.
        if thresh is not None and thresh > 0:
            ymin = min(ymin, thresh * 0.5)
            ymax = max(ymax, thresh * 2)
        if ymax <= ymin:
            # All values <= 0: keep a valid one-decade range.
            ymax = ymin * 10
        ax.set_ylim(ymin, ymax)
        # One tick per decade.
        log_ymin = np.floor(np.log10(ymin))
        log_ymax = np.ceil(np.log10(ymax))
        yticks = np.logspace(log_ymin, log_ymax, int(log_ymax - log_ymin) + 1)
        ax.set_yticks(yticks)
        ax.yaxis.set_major_formatter(plt.ScalarFormatter())  # type: ignore
        ax.yaxis.set_minor_formatter(plt.NullFormatter())  # type: ignore
    else:
        lo, hi = values.min(), values.max()
        if thresh is not None:
            lo, hi = min(lo, thresh), max(hi, thresh)
        # Pad outward by 10% of magnitude. For positive data this equals the
        # old min*0.9 / max*1.1, but it no longer clips negative values
        # (e.g. sub-zero air temperatures).
        ymin = lo - 0.1 * abs(lo)
        ymax = hi + 0.1 * abs(hi)
        if ymin == ymax:
            # All values (and threshold) are 0: avoid a singular range.
            ymin, ymax = ymin - 1, ymax + 1
        ax.set_ylim(ymin, ymax)

        tick_range = ymax - ymin
        if tick_range > 20:
            # Wide ranges: choose a "nice" 1/2/5 x 10^k step giving ~10 ticks
            # instead of a fixed 2.0 step that could yield thousands.
            raw_step = tick_range / 10
            magnitude = 10 ** math.floor(math.log10(raw_step))
            tick_spacing = next(
                m * magnitude for m in (1, 2, 5, 10) if m * magnitude >= raw_step
            )
            tick_start = np.floor(ymin / tick_spacing) * tick_spacing
        else:
            if tick_range > 10:
                tick_spacing = 2.0
            elif tick_range > 5:
                tick_spacing = 1.0
            else:
                tick_spacing = 0.5
            tick_start = np.floor(ymin)
        ax.set_yticks(np.arange(tick_start, np.ceil(ymax), tick_spacing))

    # Plot points and collect legend handles/labels.
    handles = []
    labels = []
    surface_scatter = ax.scatter(
        surface_data["Activity_Start_Date_Time"],
        surface_data["Org_Result_Value"],
        color="#1f77b4",  # darker blue for surface
        s=25,
        alpha=0.5,
        label="Surface",
        zorder=2,
    )
    handles.append(surface_scatter)
    labels.append("Surface")
    # Bottom series (and its legend entry) only when bottom data exists.
    if not bottom_data.empty:
        bottom_scatter = ax.scatter(
            bottom_data["Activity_Start_Date_Time"],
            bottom_data["Org_Result_Value"],
            color="#7fbf7b",  # muted green for bottom
            s=25,
            alpha=0.5,
            label="Bottom",
            zorder=2,
        )
        handles.append(bottom_scatter)
        labels.append("Bottom")

    # Hurricane Michael marker, only if its date falls inside the data range.
    hurricane_date = pd.Timestamp("2018-10-10")
    dates = param_data["Activity_Start_Date_Time"]
    if dates.min() <= hurricane_date <= dates.max():
        y_low, y_high = ax.get_ylim()
        line_height = y_high * 0.95
        # Data coordinates keep the line top aligned with the dot on both
        # linear and log axes, whatever the lower y-limit is.
        ax.vlines(
            hurricane_date,  # type: ignore
            y_low,
            line_height,
            colors="gray",
            linestyles="solid",
            alpha=0.6,
            linewidth=1,
            zorder=1,
        )
        ax.scatter(
            [hurricane_date],  # type: ignore
            [line_height],
            color="gray",
            s=25,
            alpha=0.8,
            zorder=2,
        )
        # Two-line annotation: bold date, then the event name beneath it.
        ax.annotate(
            "Oct 2018",
            xy=(hurricane_date, line_height),  # type: ignore
            xytext=(5, 0),
            textcoords="offset points",
            ha="left",
            va="bottom",
            color="gray",
            fontsize=10,
            weight="bold",
        )
        ax.annotate(
            "Hurricane Michael",
            xy=(hurricane_date, line_height),  # type: ignore
            xytext=(5, -12),
            textcoords="offset points",
            ha="left",
            va="bottom",
            color="gray",
            fontsize=10,
        )

    if thresh is not None:
        threshold_line = ax.axhline(
            y=thresh,
            color="red",
            linestyle=":",
            alpha=0.9,
            linewidth=1.5,
            label=f"Threshold: {thresh} {unit}",
            zorder=1,
        )
        handles.append(threshold_line)
        labels.append(f"Threshold: {thresh} {unit}")

    # These parameters are drawn without a legend.
    if parameter not in ["Depth, Secchi Disk Depth", "Temperature, Air"]:
        ax.legend(
            handles=handles,
            labels=labels,
            loc="upper right",
            frameon=True,
            ncol=1,
            bbox_to_anchor=(1.0, 1.0),
            handletextpad=0.5,
            fontsize=12,
        )

    # Only the bottom spine is shown.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_color("black")
    ax.spines["bottom"].set_linewidth(0.5)

    title = parameter
    if sector != "All":
        title += f" - {sector}"
    ax.set_title(title, fontsize=14)
    ax.set_ylabel(f"{unit}", fontsize=12)
    ax.grid(True, which="major", axis="both", alpha=0.15, linestyle="-", color="gray")
    # Hide y tick marks but keep their labels.
    ax.tick_params(axis="y", which="both", length=0)

    # Major ticks are always years; shorter spans also get labelled minor
    # month ticks, denser the shorter the span.
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
    span = year_range[1] - year_range[0]
    if span <= 1:
        minor_locator = mdates.MonthLocator()  # every month
    elif span <= 4:
        minor_locator = mdates.MonthLocator(bymonth=[1, 4, 7, 10])  # quarterly
    elif span <= 12:
        minor_locator = mdates.MonthLocator(bymonth=[1, 7])  # half-yearly
    else:
        minor_locator = None  # long spans: year labels only
    if minor_locator is not None:
        ax.xaxis.set_minor_locator(minor_locator)
        ax.xaxis.set_minor_formatter(mdates.DateFormatter("%b"))
        ax.grid(True, which="minor", axis="x", alpha=0.1, linestyle="-", color="gray")
        ax.tick_params(axis="x", which="minor", length=4, labelsize=8)
        plt.setp(ax.get_xminorticklabels(), rotation=0)

    fig.tight_layout()
    return (fig, param_data)