# Hardy_Cross / hardy_cross_plot.py
# razaali10's picture
# Update hardy_cross_plot.py
# cb823e9 verified
# hardy_cross_plot.py
"""
Headless plotting utility for the Hardy Cross app.
- Uses the non-GUI Agg backend (safe on HF Spaces / servers).
- Accepts a DataFrame of final flows (must include a 'pipe_id' column and a 'Q' column).
- If available, will use 'start_node' and 'end_node' columns to determine pipe direction.
- Otherwise, infers endpoints from the pipe_id (e.g., 'AB' or 'A-B' -> start 'A', end 'B').
- Auto-lays out nodes on a circle unless explicit positions are provided.
Example:
import pandas as pd
from hardy_cross_plot import plot_hardy_cross_network
df = pd.DataFrame({
"pipe_id": ["AB", "BC", "CD", "DA", "AC"],
"Q": [0.08, 0.05, -0.06, -0.07, 0.02]
})
path = plot_hardy_cross_network(df, save_path="hardy_cross_network_result.png")
"""
import os
import math
import matplotlib
matplotlib.use("Agg") # critical: headless backend for servers/Spaces
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Optional
import pandas as pd
import numpy as np
def _infer_endpoints_from_pipe_id(pipe_id: str) -> Tuple[str, str]:
    """
    Infer (start_node, end_node) names from a pipe identifier.

    Strategy, in order:
      1. If the id contains separator characters ('-', '–', '—', '>', '→'),
         split on any run of them and use the first and last non-blank
         pieces, e.g. 'A-B', 'A->B', 'A-->B', 'N1 → N2'.
      2. Otherwise split into runs of alphanumeric characters and use the
         first and last run, e.g. 'N1_N2' -> ('N1', 'N2').
      3. Otherwise fall back to the first and last character,
         e.g. 'AB' -> ('A', 'B').

    Returns ("?", "?") for non-string or blank input.
    """
    import re
    from itertools import groupby

    if not isinstance(pipe_id, str) or not pipe_id.strip():
        return ("?", "?")
    s = pipe_id.strip()

    # Split on a *run* of separator characters so that multi-character
    # arrows such as '->' or '-->' act as a single separator.  Splitting on
    # '-' alone would turn 'A->B' into the bogus end node '>B'.
    parts = [p.strip() for p in re.split(r"[-–—>→]+", s) if p.strip()]
    if len(parts) >= 2:
        return (parts[0], parts[-1])

    # Runs of alphanumeric characters; any other character separates them.
    chunks = ["".join(run) for is_alnum, run in groupby(s, key=str.isalnum) if is_alnum]
    if len(chunks) >= 2:
        return (chunks[0], chunks[-1])

    # Single token such as 'AB': assume one-character node names.
    return (s[0], s[-1])
def _collect_nodes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` that is guaranteed to have 'start_node' and 'end_node'.

    Columns the caller already supplied are kept unchanged. Only a missing
    column is filled in, using endpoints inferred from 'pipe_id' by
    ``_infer_endpoints_from_pipe_id``. The input frame itself is never
    modified: ``DataFrame.copy()`` makes a deep copy by default.

    Requires a 'pipe_id' column whenever at least one endpoint column is missing.
    """
    df2 = df.copy()
    missing = [col for col in ("start_node", "end_node") if col not in df2.columns]
    if not missing:
        return df2

    inferred = [_infer_endpoints_from_pipe_id(pid) for pid in df2["pipe_id"].astype(str)]
    # Fill only what is absent; never overwrite user-provided endpoints.
    if "start_node" in missing:
        df2["start_node"] = [start for start, _ in inferred]
    if "end_node" in missing:
        df2["end_node"] = [end for _, end in inferred]
    return df2
def _default_positions(nodes: list) -> Dict[str, Tuple[float, float]]:
    """
    Spread the given nodes evenly around a circle of radius 1.

    The node at index ``i`` of ``len(nodes)`` is placed at angle
    ``2 * pi * i / len(nodes)`` radians, with the first node at (1, 0).
    An empty list yields an empty mapping.
    """
    radius = 1.0
    count = max(len(nodes), 1)  # avoid dividing by zero for an empty list
    angles = (2 * math.pi * idx / count for idx in range(len(nodes)))
    return {
        name: (radius * math.cos(angle), radius * math.sin(angle))
        for name, angle in zip(nodes, angles)
    }
def _save_figure(fig, save_path: str, **savefig_kwargs) -> str:
    """
    Save ``fig`` as a 150-dpi image at ``save_path`` and close it.

    Parent directories are created if needed. The figure is closed even when
    saving fails, so a failed write does not leave it registered with pyplot.
    Extra keyword arguments are forwarded to ``Figure.savefig``.
    """
    try:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, **savefig_kwargs)
    finally:
        plt.close(fig)
    return save_path


def _save_placeholder(message: str, save_path: str, figsize: Tuple[int, int]) -> str:
    """Save a figure showing only ``message`` so callers always get an image file."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.text(0.5, 0.5, message, ha="center", va="center")
    ax.axis("off")
    fig.tight_layout()
    return _save_figure(fig, save_path)


def plot_hardy_cross_network(
    final_flows: pd.DataFrame,
    save_path: str = "hardy_cross_network_result.png",
    node_positions: Optional[Dict[str, Tuple[float, float]]] = None,
    figsize: Tuple[int, int] = (8, 5),
) -> str:
    """
    Create a simple network plot with arrows indicating flow direction and magnitude.

    Parameters
    ----------
    final_flows : pd.DataFrame
        Must contain at least 'pipe_id' and 'Q'.
        If 'start_node' and 'end_node' exist, they are used; otherwise they
        are inferred from 'pipe_id'. A positive Q is drawn start -> end and a
        negative Q is drawn end -> start. Non-numeric or NaN Q values are
        drawn as a dashed line with no arrow.
    save_path : str
        Where to save the PNG. Missing parent directories are created.
    node_positions : dict, optional
        Mapping {node: (x, y)}. If None or empty, nodes are placed on a unit
        circle. Nodes missing from the mapping are drawn at (0, 0).
    figsize : tuple
        Matplotlib figure size in inches.

    Returns
    -------
    str
        ``save_path``. It is returned only after the image was written; an
        I/O failure propagates as an exception.

    Notes
    -----
    If ``final_flows`` is None/empty, or lacks 'pipe_id' or 'Q', a
    placeholder image with an explanatory message is saved instead of
    raising. This keeps a calling UI from crashing.
    """
    if final_flows is None or len(final_flows) == 0:
        return _save_placeholder("No data to plot", save_path, figsize)
    if "pipe_id" not in final_flows.columns or "Q" not in final_flows.columns:
        return _save_placeholder(
            "Missing columns: require 'pipe_id' and 'Q'", save_path, figsize
        )

    df = _collect_nodes(final_flows)

    # Collect unique nodes and positions
    nodes = sorted(set(df["start_node"].astype(str)) | set(df["end_node"].astype(str)))
    if not node_positions:
        node_positions = _default_positions(nodes)

    # Coerce flows to float once; non-numeric entries become NaN.
    flows = pd.to_numeric(df["Q"], errors="coerce").astype(float)

    # Arrow size is scaled by |Q| relative to the largest finite |Q|. NaN and
    # inf are excluded so they cannot distort the scale, and taking max() of
    # a non-empty finite Series avoids np.nanmax's all-NaN RuntimeWarning.
    # Fall back to 1.0 when there is no usable non-zero flow.
    finite_abs = flows[np.isfinite(flows)].abs()
    q_max = float(finite_abs.max()) if len(finite_abs) else 0.0
    if q_max <= 0:
        q_max = 1.0

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_title("Hardy Cross Network Flows", fontsize=14, pad=12)

    # Draw nodes
    for node in nodes:
        x, y = node_positions.get(node, (0.0, 0.0))
        ax.plot(x, y, "ko", ms=5)
        ax.text(x, y + 0.06, str(node), ha="center", va="bottom", fontsize=11, fontweight="bold")

    # Draw pipes with arrows and labels
    for start, end, q in zip(
        df["start_node"].astype(str), df["end_node"].astype(str), flows
    ):
        x1, y1 = node_positions.get(start, (0.0, 0.0))
        x2, y2 = node_positions.get(end, (0.0, 0.0))

        if not np.isfinite(q):
            # Unknown flow: show the pipe, but no direction can be drawn.
            ax.plot([x1, x2], [y1, y2], linestyle="--", color="tab:gray", lw=1.0)
        else:
            # Direction: if Q >= 0, arrow from start->end; else reverse
            if q >= 0:
                sx, sy, tx, ty = x1, y1, x2, y2
            else:
                sx, sy, tx, ty = x2, y2, x1, y1
            frac = abs(q) / q_max  # in [0, 1] because q_max is the largest finite |Q|
            ax.annotate(
                "",
                xy=(tx, ty),
                xytext=(sx, sy),
                arrowprops=dict(
                    arrowstyle="-|>",
                    lw=1.0 + 3.0 * frac,  # line width in points
                    shrinkA=5,
                    shrinkB=5,
                    mutation_scale=12 + 30 * frac,  # arrow-head size
                    color="tab:blue",
                ),
            )

        # Label with flow value near the pipe midpoint
        ax.text(
            0.5 * (x1 + x2),
            0.5 * (y1 + y2),
            f"{q:.3f} m³/s",
            ha="center", va="center",
            fontsize=10, color="tab:blue",
        )

    fig.tight_layout()
    return _save_figure(fig, save_path, bbox_inches="tight")