# waterdb/plots/correlation.py
# Deployed by github-actions[bot] ("Deploy from GitHub Actions"), commit 7f2633f
from typing import Literal
import altair as alt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
import streamlit as st
from matplotlib.figure import Figure
from utils.data_loading import timer
def _format_exclusion_footnote(excluded_params: pd.Series, threshold: float) -> str:
    """Return footnote text listing parameters dropped for low data completeness."""
    lines = [f"Excluded parameters (<{threshold:.0%} data completeness):"]
    lines.extend(
        f" - {param}: {completeness_val:.1%} complete"
        for param, completeness_val in excluded_params.items()
    )
    return "\n".join(lines)


def _draw_correlation_heatmap(corr: pd.DataFrame, ax, label_color: str) -> None:
    """Draw the masked correlation heatmap on ``ax`` and style it and its colorbar."""
    # np.triu (k=0) masks the upper triangle *including* the diagonal.
    mask = np.triu(np.ones_like(corr, dtype=bool))
    sns.heatmap(
        corr,
        mask=mask,
        annot=True,
        cmap="RdBu_r",
        center=0,
        vmin=-1,
        vmax=1,
        ax=ax,
        yticklabels=1,
        cbar=True,
        xticklabels=1,
        annot_kws={"size": 5},
        fmt=".2f",
    )

    # Style spines according to guide
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Rotate x-axis labels and anchor them so they line up with their columns
    ax.set_xticklabels(
        ax.get_xticklabels(),
        rotation=45,
        ha="right",
        rotation_mode="anchor",
        color=label_color,
    )
    # Remove tick marks but keep labels
    for axis in ("x", "y"):
        ax.tick_params(
            axis=axis, pad=5, labelsize=5, length=0, colors=label_color
        )

    # seaborn adds the colorbar as the figure's most recently created axes.
    colorbar_ax = ax.figure.axes[-1]
    # Pin the tick locations before relabelling them; relabelling without
    # fixed ticks triggers matplotlib's FixedFormatter warning.
    ticks = colorbar_ax.get_yticks()
    colorbar_ax.set_yticks(ticks)
    colorbar_ax.set_yticklabels(
        [f"{x:>8.2f}" for x in ticks],
        size=5,
        color=label_color,
    )

    # Horizontal y-axis labels, no axis titles
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, color=label_color)
    ax.set_xlabel("")
    ax.set_ylabel("")


@st.cache_data
@timer(include_params=True)
def plot_parameter_correlations(
    df: pd.DataFrame,
    analyte_names: list[str],
    subset_by: str,
    subset: str,
    filter_by: str,
    threshold: float = 0.2,
    corr_method: Literal["pearson", "kendall", "spearman"] = "kendall",
) -> tuple[Figure, pd.DataFrame]:
    """
    Creates a correlation heatmap showing relationships between water quality parameters,
    with additional information about data completeness.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing water quality measurements. Must have columns:
        - Org_Analyte_Name: Name of the analyte
        - Org_Result_Value: Measurement value
        - Activity_Start_Date_Time: Timestamp of measurement
        - Reporting_Year: Year of measurement
    analyte_names : list[str]
        List of analyte names to include in correlation analysis
    subset_by : str
        Label of the subsetting dimension shown in the title (e.g., "Sector")
    subset : str
        Value of the subsetting dimension shown in the title
    filter_by : str
        Sample position filter shown in the subtitle ("Surface", "Bottom", or
        "All"; "All" is displayed as "Surface and Bottom")
    threshold : float, default=0.2
        Minimum data completeness (0-1). Completeness of a parameter is the
        fraction of timestamps in the pivot table that have a value for it.
        Parameters below this threshold are excluded from the correlation
        analysis and listed in the footnote.
    corr_method : {"pearson", "kendall", "spearman"}, default="kendall"
        Method of correlation passed to ``DataFrame.corr``:
        - "pearson": Standard correlation coefficient
        - "kendall": Kendall Tau correlation coefficient
        - "spearman": Spearman rank correlation coefficient

    Returns
    -------
    tuple[Figure, pd.DataFrame]
        - Figure: Matplotlib figure containing:
            - Correlation heatmap with values (or a message when no parameter
              meets the threshold)
            - Title showing subset, sample position, year and ``n``
            - Footnote listing excluded parameters
        - DataFrame: Timestamp x parameter pivot table (abbreviated column
          names, only parameters meeting the threshold) used for the
          correlation analysis

    Notes
    -----
    - Uses abbreviated parameter names for cleaner display (e.g., "DO" for "Dissolved Oxygen")
    - Masks the upper triangle of the correlation matrix, including the diagonal
    - Colors correlations using RdBu_r colormap centered at 0
    - Results are cached by ``st.cache_data``
    """
    # Constants from style guide
    GREY30 = "#4d4d4d"  # Dark grey for titles
    GREY40 = "#666666"  # Medium grey for axes and labels

    name_mapping = {
        "Depth, Secchi Disk Depth": "Secchi Depth",
        "Dissolved Oxygen": "DO",
        "Fecal Coliform (MPN)": "Fecal Coliform",
        "Total Nitrogen": "TN",
        "Total Phosphorus": "TP",
    }

    # One row per timestamp, one column per requested analyte. Duplicate
    # (timestamp, analyte) values are averaged by pivot_table's default aggfunc.
    pivot_df = (
        df[df["Org_Analyte_Name"].isin(analyte_names)]
        .pivot_table(
            index="Activity_Start_Date_Time",
            columns="Org_Analyte_Name",
            values="Org_Result_Value",
            observed=False,
        )
        .rename(columns=name_mapping)
    )

    # Fraction of timestamps that have a value for each parameter
    completeness = pivot_df.notna().mean()
    valid_params = completeness[completeness >= threshold].index
    excluded_params = completeness[completeness < threshold]

    # Keep only parameters meeting the threshold, then correlate
    pivot_df = pivot_df[valid_params]
    corr = pivot_df.corr(method=corr_method)

    # NOTE(review): counts every row of ``df`` (all analytes), not the number
    # of timestamps in ``pivot_df`` — confirm this is the intended "n".
    n_samples = len(df)

    fig = plt.figure(figsize=(3, 3.5))
    gs = fig.add_gridspec(
        3,
        1,
        height_ratios=[
            1,  # Title space
            4,  # Heatmap
            1.5,  # Footnote
        ],
        hspace=0.4,
    )
    title_ax = fig.add_subplot(gs[0])
    heatmap_ax = fig.add_subplot(gs[1])
    footnote_ax = fig.add_subplot(gs[2])

    if corr.empty:
        # Nothing met the threshold, so there is nothing to correlate; say so
        # instead of drawing an empty heatmap.
        heatmap_ax.set_axis_off()
        heatmap_ax.text(
            0.5,
            0.5,
            "No parameters meet the data completeness threshold",
            ha="center",
            va="center",
            fontsize=6,
            fontstyle="italic",
            transform=heatmap_ax.transAxes,
            color=GREY40,
        )
    else:
        _draw_correlation_heatmap(corr, heatmap_ax, GREY40)

    # Footnote axis is a bare text area
    footnote_ax.set_frame_on(False)
    footnote_ax.set_xticks([])
    footnote_ax.set_yticks([])
    if not excluded_params.empty:
        footnote_ax.text(
            0.01,
            0.40,
            _format_exclusion_footnote(excluded_params, threshold),
            ha="left",
            va="center",
            fontsize=5,
            fontstyle="italic",
            transform=footnote_ax.transAxes,
            color=GREY40,  # Style guide color
            bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
        )

    # Title axis is a bare text area; the texts are placed in figure coordinates
    title_ax.set_frame_on(False)
    title_ax.set_xticks([])
    title_ax.set_yticks([])
    display_filter = "Surface and Bottom" if filter_by == "All" else filter_by
    year_info = (
        f"Reporting Year {df['Reporting_Year'].iloc[0]}"
        if len(df["Reporting_Year"].unique()) == 1
        else "All Years"
    )
    title_ax.text(
        0.45,
        0.8,
        f"{subset_by}: {subset}",
        ha="center",
        va="center",
        fontsize=8,
        fontweight="bold",
        transform=fig.transFigure,
        color=GREY30,  # Style guide title color
    )
    title_ax.text(
        0.45,
        0.75,
        f"{display_filter}, {year_info} (n={n_samples:,})",
        ha="center",
        va="bottom",
        fontsize=6,
        fontstyle="italic",
        transform=fig.transFigure,
        color=GREY40,  # Style guide color
    )

    # Explicit spacing instead of tight_layout (the title texts live in
    # figure coordinates, which tight_layout would not account for).
    fig.subplots_adjust(left=0.15, right=0.95, bottom=0.02, top=0.85, hspace=0.4)
    return fig, pivot_df
def plot_np_ratios(df: pd.DataFrame) -> Figure:
    """
    Create a visualization of N:P ratios over time and their distribution.

    Parameters:
    -----------
    df : pd.DataFrame
        Long-format nutrient measurements with Org_Analyte_Name,
        Org_Result_Value, Activity_Start_Date_Time and Sector columns. Rows for
        "Total Nitrogen" and "Total Phosphorus" are paired per
        (Activity_Start_Date_Time, Sector); both analytes are expected to be
        present.

    Returns:
    --------
    Figure
        Two stacked axes: N:P ratio over time colored by sector, and a
        histogram of the ratios. Both mark the Redfield ratio (16:1).

    Notes:
    ------
    A Total Phosphorus value of 0 gives an infinite ratio. Such ratios are
    treated as missing, so they are neither plotted nor counted in the
    ``n`` annotation.
    """
    # Style constants from guide
    GREY30 = "#4d4d4d"  # Dark grey for titles
    GREY40 = "#666666"  # Medium grey for axes and labels
    REDFIELD_RATIO = 16

    # One row per (time, sector) with N and P side by side
    nutrients_df = (
        df[df["Org_Analyte_Name"].isin(["Total Nitrogen", "Total Phosphorus"])]
        .pivot_table(
            index=["Activity_Start_Date_Time", "Sector"],
            columns="Org_Analyte_Name",
            values="Org_Result_Value",
            observed=True,
        )
        .reset_index()
    )
    ratio = nutrients_df["Total Nitrogen"] / nutrients_df["Total Phosphorus"]
    # Division by a zero phosphorus value yields +/-inf, which dropna() would
    # keep; convert it to NaN so it is treated as missing.
    nutrients_df["N:P Ratio"] = ratio.replace([np.inf, -np.inf], np.nan)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

    # Time series plot with colors by sector
    sns.scatterplot(
        data=nutrients_df,
        x="Activity_Start_Date_Time",
        y="N:P Ratio",
        hue="Sector",
        ax=ax1,
        alpha=0.6,
        s=20,  # Smaller point size for better clarity
    )
    ax1.axhline(
        y=REDFIELD_RATIO,
        color="red",
        linestyle="--",
        alpha=0.7,
        linewidth=1.5,
        label="Redfield Ratio (16:1)",
    )

    # Shared axis styling: bottom spine only, no tick marks, light y grid
    for ax in (ax1, ax2):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_color(GREY40)
        ax.spines["bottom"].set_linewidth(0.5)
        ax.spines["left"].set_visible(False)
        ax.tick_params(axis="both", which="both", length=0, colors=GREY40)
        ax.grid(True, axis="y", alpha=0.15, linestyle="-", color="gray")

    ax1.set_ylabel("N:P Ratio", color=GREY40, fontsize=10)
    ax1.set_xlabel("Date", color=GREY40, fontsize=10)
    ax1.set_title("N:P Ratio Over Time", color=GREY30, fontsize=10, pad=10)

    # Legend outside the plot area, to the right
    legend = ax1.legend(
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        frameon=False,
        title="Sector",
    )
    legend.get_title().set_color(GREY40)
    for text in legend.get_texts():
        text.set_color(GREY40)

    # Distribution of the (finite) ratios
    sns.histplot(
        x=nutrients_df["N:P Ratio"].dropna(),
        ax=ax2,
        alpha=0.6,
    )
    ax2.axvline(
        x=REDFIELD_RATIO,
        color="red",
        linestyle="--",
        alpha=0.7,
        linewidth=1.5,
        label="Redfield Ratio (16:1)",
    )
    ax2.set_xlabel("N:P Ratio", color=GREY40, fontsize=10)
    ax2.set_ylabel("Count", color=GREY40, fontsize=10)
    ax2.set_title("Distribution of N:P Ratios", color=GREY30, fontsize=10, pad=10)
    ax2.legend(frameon=False)

    # Sample size annotation: number of finite ratios actually plotted
    n_samples = len(nutrients_df.dropna(subset=["N:P Ratio"]))
    ax2.text(
        0.02,
        0.98,
        f"n = {n_samples:,}",
        transform=ax2.transAxes,
        verticalalignment="top",
        fontsize=8,
        bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
        color=GREY40,
    )

    # Lay out this figure (not whichever is current in pyplot), then leave
    # room on the right for the outside legend.
    fig.tight_layout()
    fig.subplots_adjust(right=0.85)
    return fig
def altair_plot_np_ratios(df: pd.DataFrame) -> alt.VConcatChart:
    """
    Build an interactive Altair view of N:P ratios over time and their distribution.

    Parameters:
    -----------
    df : pd.DataFrame
        Long-format nutrient measurements with Org_Analyte_Name,
        Org_Result_Value, Activity_Start_Date_Time and Sector columns. Rows for
        "Total Nitrogen" and "Total Phosphorus" are paired per
        (Activity_Start_Date_Time, Sector); both analytes are expected to be
        present.

    Returns:
    --------
    alt.VConcatChart
        A time-series scatter colored by sector, stacked above a histogram of
        the ratios. Each has a red rule at the Redfield ratio (16:1).

    Notes:
    ------
    As in ``plot_np_ratios``, infinite ratios (Total Phosphorus of 0) are
    treated as missing.
    """
    redfield_ratio = 16

    # One row per (time, sector) with N and P side by side
    nutrients_df = (
        df[df["Org_Analyte_Name"].isin(["Total Nitrogen", "Total Phosphorus"])]
        .pivot_table(
            index=["Activity_Start_Date_Time", "Sector"],
            columns="Org_Analyte_Name",
            values="Org_Result_Value",
            observed=True,
        )
        .reset_index()
    )
    ratio = nutrients_df["Total Nitrogen"] / nutrients_df["Total Phosphorus"]
    # Zero phosphorus gives +/-inf; treat it as missing.
    nutrients_df["N:P Ratio"] = ratio.replace([np.inf, -np.inf], np.nan)

    # The colon in the column name must be escaped in Altair field shorthand.
    ratio_field = r"N\:P Ratio:Q"

    time_series = (
        alt.Chart(nutrients_df)
        .mark_circle(size=60)
        .encode(
            x=alt.X(
                "Activity_Start_Date_Time:T",
                axis=alt.Axis(format="%Y", tickCount="year"),
                title="Date",
            ),
            y=alt.Y(ratio_field, title="N:P Ratio"),
            color="Sector:N",
            tooltip=[
                alt.Tooltip("Activity_Start_Date_Time:T", title="Date"),
                alt.Tooltip(ratio_field, format=".0f", title="N:P Ratio"),
                alt.Tooltip("Sector:N", title="Sector"),
            ],
        )
        .properties(title="N:P Ratio Over Time", width=600, height=300)
        .interactive()
    )
    redfield_line = (
        alt.Chart(pd.DataFrame({"y": [redfield_ratio]}))
        .mark_rule(color="red")
        .encode(y="y:Q")
    )

    histogram = (
        alt.Chart(nutrients_df)
        .mark_bar()
        .encode(
            x=alt.X(ratio_field, bin=alt.Bin(maxbins=30), title="N:P Ratio"),
            y="count()",
            tooltip=["count()"],
        )
        .properties(title="Distribution of N:P Ratios", width=600, height=300)
        .interactive()
    )
    redfield_hist_line = (
        alt.Chart(pd.DataFrame({"x": [redfield_ratio]}))
        .mark_rule(color="red")
        .encode(x="x:Q")
    )

    return alt.vconcat(
        time_series + redfield_line, histogram + redfield_hist_line
    ).resolve_scale(y="independent")
def plot_do_temp_relationship(df: pd.DataFrame) -> Figure:
    """
    Scatter dissolved oxygen against water temperature, with a linear trend per sample position.

    Parameters:
    -----------
    df : pd.DataFrame
        Long-format measurements with Org_Analyte_Name, Org_Result_Value,
        Activity_Start_Date_Time, Station_Number and Sample_Position columns.
        Only "Dissolved Oxygen" and "Temperature, Water" rows are used, and only
        observations where both were recorded are plotted.

    Returns:
    --------
    Figure
        Figure of the seaborn ``lmplot`` grid. It shows Bottom and Surface
        points with fitted lines and a dashed 4.8 mg/L DO threshold line.
    """
    paired_analytes = ["Dissolved Oxygen", "Temperature, Water"]
    threshold_mg_l = 4.8
    threshold_color = "#FF8C00"

    # Widen to one row per (time, station, position), DO and temperature as columns.
    wide = df.loc[df["Org_Analyte_Name"].isin(paired_analytes)].pivot_table(
        index=["Activity_Start_Date_Time", "Station_Number", "Sample_Position"],
        columns="Org_Analyte_Name",
        values="Org_Result_Value",
        observed=True,
    )
    do_temp_data = wide.reset_index().dropna(subset=paired_analytes)

    grid = sns.lmplot(
        data=do_temp_data,
        x="Temperature, Water",
        y="Dissolved Oxygen",
        hue="Sample_Position",
        hue_order=["Bottom", "Surface"],  # Bottom is drawn first
        palette={"Surface": "#5AA4D8", "Bottom": "#1B4B8A"},
        # Layering: threshold line (1) < points (2) < trend lines (3)
        scatter_kws={"alpha": 0.5, "zorder": 2, "s": 20},
        line_kws={"zorder": 3, "linewidth": 1},
        height=8,
        aspect=1.5,
        legend=False,
    )
    ax = grid.ax

    # DO threshold, labelled at the left edge of the current x range
    ax.axhline(
        y=threshold_mg_l,
        color=threshold_color,
        linestyle="--",
        alpha=0.9,
        zorder=1,
        linewidth=1,
    )
    left_edge = ax.get_xlim()[0]
    ax.text(
        left_edge,
        4.9,
        " 4.8 mg/L DO threshold",
        ha="left",
        va="bottom",
        color=threshold_color,
        alpha=0.9,
    )

    # Keep only a thin black bottom spine
    for side in ("top", "right", "left"):
        ax.spines[side].set_visible(False)
    bottom_spine = ax.spines["bottom"]
    bottom_spine.set_color("black")
    bottom_spine.set_linewidth(0.5)

    grid.set_axis_labels("Water Temperature (°C)", "Dissolved Oxygen (mg/L)")
    ax.set_title("Dissolved Oxygen vs Water Temperature", pad=20, fontsize=16)

    # Legend entries follow drawing order (Bottom, Surface); reverse them so
    # Surface is listed first.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        list(reversed(handles)),
        list(reversed(labels)),
        bbox_to_anchor=(1.0, 1.0),
        loc="upper right",
        frameon=False,
        handletextpad=0.5,
    )

    # Light horizontal grid, no y tick marks
    ax.grid(True, axis="y", alpha=0.15, linestyle="-", color="gray")
    ax.tick_params(axis="y", which="both", length=0)

    # y range always includes the threshold, padded, with ticks every 2 mg/L
    do_values = do_temp_data["Dissolved Oxygen"]
    lowest = min(do_values.min(), threshold_mg_l)
    ymin = max(int(lowest * 0.9) - 1, 0)
    ymax = do_values.max() * 1.1
    ax.set_ylim(ymin, ymax)
    ax.set_yticks(np.arange(ymin, ymax, 2))

    return grid.figure
def plotly_plot_do_temp_relationship(df: pd.DataFrame) -> go.Figure:
    """
    Create an interactive scatter plot of DO vs temperature with regression lines using Plotly.
    Mirrors the features of the matplotlib/seaborn version (per-position trend
    lines and a 4.8 mg/L DO threshold).

    Parameters:
    -----------
    df : pd.DataFrame
        Long-format measurements with Org_Analyte_Name, Org_Result_Value,
        Activity_Start_Date_Time, Station_Number, Sample_Position and Sector
        columns. Only "Dissolved Oxygen" and "Temperature, Water" rows are used.

    Returns:
    --------
    go.Figure
        Plotly figure. A sample position with no paired observations is
        omitted. Its trend line is omitted when fewer than two distinct
        temperatures are available to fit.
    """
    # One row per (time, station, position, sector) with DO and temperature
    # side by side; keep rows where both were recorded.
    do_temp_data = (
        df[df["Org_Analyte_Name"].isin(["Dissolved Oxygen", "Temperature, Water"])]
        .pivot_table(
            index=[
                "Activity_Start_Date_Time",
                "Station_Number",
                "Sample_Position",
                "Sector",  # Added for tooltip
            ],
            columns="Org_Analyte_Name",
            values="Org_Result_Value",
            observed=True,
        )
        .reset_index()
        .dropna(subset=["Dissolved Oxygen", "Temperature, Water"])
    )

    fig = go.Figure()
    colors = {"Surface": "#8da0cb", "Bottom": "#fc8d62"}

    for position in ["Surface", "Bottom"]:
        pos_data = do_temp_data[do_temp_data["Sample_Position"] == position]
        if pos_data.empty:
            # e.g. the data were filtered to a single sample position;
            # np.polyfit would raise on empty input.
            continue

        fig.add_trace(
            go.Scatter(
                x=pos_data["Temperature, Water"],
                y=pos_data["Dissolved Oxygen"],
                mode="markers",
                name=position,
                marker=dict(color=colors[position], size=8, opacity=0.6),
                hovertemplate=(
                    "Temperature: %{x:.1f}°C<br>"
                    "DO: %{y:.1f} mg/L<br>"
                    f"Position: {position}<br>"
                    "Station: %{customdata[0]}<br>"
                    "Sector: %{customdata[1]}<br>"
                    "<extra></extra>"
                ),
                customdata=pos_data[["Station_Number", "Sector"]],
            )
        )

        temperatures = pos_data["Temperature, Water"]
        # A least-squares line is only determined by >= 2 distinct x values.
        if temperatures.nunique() < 2:
            continue
        slope, intercept = np.polyfit(temperatures, pos_data["Dissolved Oxygen"], 1)
        x_range = np.linspace(temperatures.min(), temperatures.max(), 100)
        fig.add_trace(
            go.Scatter(
                x=x_range,
                y=slope * x_range + intercept,
                mode="lines",
                line=dict(color=colors[position], dash="dash"),
                name=f"{position} Trend",
                hovertemplate=None,
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # DO threshold line
    fig.add_hline(
        y=4.8,
        line=dict(color="#FF8C00", width=1, dash="dash"),
        opacity=0.5,
        annotation_text="4.8 mg/L DO threshold",
        annotation_position="left",
        annotation=dict(
            font=dict(color="#FF8C00", size=12),
            xanchor="left",
            yanchor="bottom",
            opacity=0.8,
        ),
    )

    fig.update_layout(
        title=dict(
            text="Dissolved Oxygen vs Water Temperature",
            x=0.5,
            y=0.95,
            xanchor="center",
            yanchor="top",
            font=dict(size=16),
        ),
        xaxis_title="Water Temperature (°C)",
        yaxis_title="Dissolved Oxygen (mg/L)",
        legend_title="Sample Position",
        legend=dict(
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.05,
        ),
        template="plotly_white",
        width=800,
        height=600,
        showlegend=True,
    )
    grid_style = dict(showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)")
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)
    return fig
def altair_plot_do_temp_relationship(df: pd.DataFrame) -> alt.LayerChart:
    """
    Create an interactive scatter plot of DO vs temperature with regression lines using Altair.
    Uses the same 4.8 mg/L DO threshold (dashed, orange) as the
    matplotlib/seaborn and Plotly versions.

    Parameters:
    -----------
    df : pd.DataFrame
        Long-format measurements with Org_Analyte_Name, Org_Result_Value,
        Activity_Start_Date_Time, Station_Number, Sample_Position and Sector
        columns. Only "Dissolved Oxygen" and "Temperature, Water" rows are used.

    Returns:
    --------
    alt.LayerChart
        Layered, interactive chart: points and a per-position regression line
        colored by Sample_Position, plus the threshold rule and its label.
    """
    do_threshold = 4.8  # mg/L, matches the seaborn and Plotly versions
    threshold_color = "#FF8C00"

    # One row per (time, station, position, sector) with DO and temperature
    # side by side; keep rows where both were recorded.
    do_temp_data = (
        df[df["Org_Analyte_Name"].isin(["Dissolved Oxygen", "Temperature, Water"])]
        .pivot_table(
            index=[
                "Activity_Start_Date_Time",
                "Station_Number",
                "Sample_Position",
                "Sector",
            ],
            columns="Org_Analyte_Name",
            values="Org_Result_Value",
            observed=True,
        )
        .reset_index()
        .dropna(subset=["Dissolved Oxygen", "Temperature, Water"])
    )

    # Shared by the points and the regression lines so their colors agree
    position_scale = alt.Scale(
        domain=["Surface", "Bottom"],
        range=["#8da0cb", "#fc8d62"],  # Muted blue and orange
    )

    scatter = (
        alt.Chart(do_temp_data)
        .mark_circle(size=60, opacity=0.6)
        .encode(
            x=alt.X(
                "Temperature, Water:Q",
                title="Water Temperature (°C)",
                scale=alt.Scale(zero=False),
            ),
            y=alt.Y(
                "Dissolved Oxygen:Q",
                title="Dissolved Oxygen (mg/L)",
                scale=alt.Scale(zero=False),
            ),
            color=alt.Color(
                "Sample_Position:N",
                scale=position_scale,
                legend=alt.Legend(title="Sample Position"),
            ),
            tooltip=[
                alt.Tooltip("Temperature, Water:Q", title="Temperature", format=".1f"),
                alt.Tooltip("Dissolved Oxygen:Q", title="DO", format=".1f"),
                alt.Tooltip("Sample_Position:N", title="Position"),
                alt.Tooltip("Sector:N", title="Sector"),
                alt.Tooltip("Station_Number:N", title="Station"),
            ],
        )
    )

    # One linear fit per sample position
    regression = (
        scatter.transform_regression(
            "Temperature, Water", "Dissolved Oxygen", groupby=["Sample_Position"]
        )
        .mark_line(size=2)
        .encode(color=alt.Color("Sample_Position:N", scale=position_scale))
    )

    threshold_line = (
        alt.Chart(pd.DataFrame({"y": [do_threshold]}))
        .mark_rule(strokeDash=[4, 4], color=threshold_color, opacity=0.5)
        .encode(y="y:Q")
    )
    # Label sits just above the rule at the minimum plotted temperature
    threshold_label = (
        alt.Chart(
            pd.DataFrame({"x": [do_temp_data["Temperature, Water"].min()], "y": [4.9]})
        )
        .mark_text(
            align="left",
            baseline="bottom",
            color=threshold_color,
            opacity=0.5,
            text=f" {do_threshold:g} mg/L DO threshold",
        )
        .encode(x="x:Q", y="y:Q")
    )

    return (
        alt.layer(scatter, regression, threshold_line, threshold_label)
        .properties(
            width=800,
            height=750,
        )
        .configure_axis(grid=True, gridOpacity=0.3)
        .interactive()
    )