File size: 7,314 Bytes
cb823e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b15e7ec
cb823e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa16b2c
cb823e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa16b2c
cb823e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# hardy_cross_plot.py
"""
Headless plotting utility for the Hardy Cross app.

- Uses the non-GUI Agg backend (safe on HF Spaces / servers).
- Accepts a DataFrame of final flows (must include a 'pipe_id' column and a 'Q' column).
- If available, will use 'start_node' and 'end_node' columns to determine pipe direction.
- Otherwise, infers endpoints from the pipe_id (e.g., 'AB' -> 'A' -> 'B').
- Auto-lays out nodes on a circle unless explicit positions are provided.

Example:
    import pandas as pd
    from hardy_cross_plot import plot_hardy_cross_network

    df = pd.DataFrame({
        "pipe_id": ["AB", "BC", "CD", "DA", "AC"],
        "Q": [0.08, 0.05, -0.06, -0.07, 0.02]
    })
    path = plot_hardy_cross_network(df, save_path="hardy_cross_network_result.png")
"""

import os
import math
import matplotlib
matplotlib.use("Agg")  # critical: headless backend for servers/Spaces
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Optional

import pandas as pd
import numpy as np


def _infer_endpoints_from_pipe_id(pipe_id: str) -> Tuple[str, str]:
    """
    Try to infer start/end node names from a pipe_id.

    Strategy (first rule that applies wins):
        1. If the id contains a separator ('->', '-', '—', '>', '→'), split on
           it and use the first and last non-blank parts,
           e.g. 'A-B' -> ('A', 'B'), 'N1->N2' -> ('N1', 'N2').
        2. Otherwise split into runs of alphanumeric characters and use the
           first and last run, e.g. 'A_B' or 'A B' -> ('A', 'B').
        3. Otherwise fall back to the first and last characters,
           e.g. 'AB' -> ('A', 'B').

    Returns ("?", "?") when pipe_id is not a string or is blank.
    """
    if not isinstance(pipe_id, str) or not pipe_id.strip():
        return ("?", "?")

    s = pipe_id.strip()

    # '->' must be tried before its single-character pieces: splitting 'A->B'
    # on '-' alone would give ('A', '>B').
    for sep in ("->", "-", "—", ">", "→"):
        if sep in s:
            # Strip before filtering so whitespace-only parts are discarded
            # instead of becoming an empty node name.
            parts = [p.strip() for p in s.split(sep) if p.strip()]
            if len(parts) >= 2:
                return (parts[0], parts[-1])

    # Split into maximal runs of alphanumeric characters.
    chunks = []
    current = []
    for ch in s:
        if ch.isalnum():
            current.append(ch)
        elif current:
            chunks.append("".join(current))
            current = []
    if current:
        chunks.append("".join(current))
    if len(chunks) >= 2:
        return (chunks[0], chunks[-1])

    # Fallback: first and last character (e.g. 'AB' -> ('A', 'B')).
    return (s[0], s[-1])


def _collect_nodes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* that is guaranteed to have 'start_node' and 'end_node'.

    If both columns are already present they are kept as-is. If either one is
    missing, both are (re)derived from 'pipe_id' with
    _infer_endpoints_from_pipe_id, so the pair stays consistent.
    The input frame is not modified; ``DataFrame.copy()`` makes a deep copy by
    default.
    """
    out = df.copy()
    if all(col in out.columns for col in ("start_node", "end_node")):
        return out

    endpoints = [
        _infer_endpoints_from_pipe_id(pid) for pid in out["pipe_id"].astype(str)
    ]
    out["start_node"] = [first for first, _ in endpoints]
    out["end_node"] = [last for _, last in endpoints]
    return out


def _default_positions(nodes: list) -> Dict[str, Tuple[float, float]]:
    """
    Spread *nodes* evenly around a unit circle centred on the origin.

    The i-th node (in list order) is placed at angle 2*pi*i/len(nodes), so the
    first node lands at (1, 0) and the rest follow counter-clockwise.
    Returns a {node: (x, y)} dict; an empty list gives an empty dict.
    """
    radius = 1.0
    slots = max(len(nodes), 1)  # avoid dividing by zero for an empty list

    def _point(idx: int) -> Tuple[float, float]:
        angle = 2 * math.pi * idx / slots
        return (radius * math.cos(angle), radius * math.sin(angle))

    return {node: _point(idx) for idx, node in enumerate(nodes)}


def _ensure_parent_dir(path: str) -> None:
    """Create the directory part of *path* if it is non-empty and missing."""
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)


def _save_placeholder(message: str, save_path: str, figsize: Tuple[float, float]) -> str:
    """
    Save a blank figure that only shows *message* and return *save_path*.

    Used instead of raising so the UI consuming the PNG still gets an image.
    """
    _ensure_parent_dir(save_path)
    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.text(0.5, 0.5, message, ha="center", va="center")
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # pyplot keeps every open figure referenced; close it even on failure
        # so a long-running server does not accumulate figures.
        plt.close(fig)
    return save_path


def plot_hardy_cross_network(
    final_flows: pd.DataFrame,
    save_path: str = "hardy_cross_network_result.png",
    node_positions: Optional[Dict[str, Tuple[float, float]]] = None,
    figsize: Tuple[float, float] = (8, 5),
) -> str:
    """
    Create a simple network plot with arrows indicating flow direction and magnitude.

    Parameters
    ----------
    final_flows : pd.DataFrame
        Must contain at least 'pipe_id' and 'Q'.
        If 'start_node' and 'end_node' exist, they are used; otherwise inferred
        from 'pipe_id'. 'Q' is coerced to numbers; non-numeric, NaN or infinite
        values are still labelled but drawn without an arrow.
    save_path : str
        Where to save the PNG. Missing parent directories are created.
    node_positions : dict, optional
        Mapping {node: (x, y)}. Keys are matched as strings, so integer node ids
        work too. Nodes without an entry are drawn at (0, 0). If None or empty,
        nodes are placed on a unit circle.
    figsize : tuple
        Matplotlib figure size in inches.

    Returns
    -------
    str
        *save_path*. If the input is empty or lacks 'pipe_id'/'Q', a placeholder
        image with an explanatory message is saved there instead.

    Errors from creating the directory or writing the file propagate to the
    caller; the figure is closed in every case.
    """
    if final_flows is None or len(final_flows) == 0:
        return _save_placeholder("No data to plot", save_path, figsize)

    if "pipe_id" not in final_flows.columns or "Q" not in final_flows.columns:
        return _save_placeholder(
            "Missing columns: require 'pipe_id' and 'Q'", save_path, figsize
        )

    df = _collect_nodes(final_flows)

    # Work column-wise instead of with iterrows(): iterrows() builds one Series
    # per row and can upcast an all-numeric row (int node id 1 -> 1.0), which
    # then no longer matches the string node keys computed below.
    starts = df["start_node"].astype(str).tolist()
    ends = df["end_node"].astype(str).tolist()
    flows = pd.to_numeric(df["Q"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)

    # Collect unique nodes and positions
    nodes = sorted(set(starts) | set(ends))
    if node_positions:
        # Node names are compared as strings, so normalise the user's keys too.
        positions = {str(k): v for k, v in node_positions.items()}
    else:
        positions = _default_positions(nodes)

    # Scale arrows by |Q| relative to the largest finite, non-zero |Q|.
    q_abs = np.abs(flows)
    usable = np.isfinite(q_abs) & (q_abs > 0)
    q_max = float(q_abs[usable].max()) if usable.any() else 1.0

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.set_aspect("equal")
        ax.axis("off")
        ax.set_title("Hardy Cross Network Flows", fontsize=14, pad=12)

        # Draw nodes
        for node in nodes:
            x, y = positions.get(node, (0.0, 0.0))
            ax.plot(x, y, "ko", ms=5)
            ax.text(x, y + 0.06, node, ha="center", va="bottom", fontsize=11, fontweight="bold")

        # Draw pipes with arrows and labels
        for s, e, q in zip(starts, ends, flows):
            q = float(q)
            x1, y1 = positions.get(s, (0.0, 0.0))
            x2, y2 = positions.get(e, (0.0, 0.0))

            # A NaN/inf flow has no meaningful direction or size: label only.
            if math.isfinite(q):
                # Direction: Q >= 0 flows start -> end; negative Q flows end -> start.
                if q >= 0:
                    tail, head = (x1, y1), (x2, y2)
                else:
                    tail, head = (x2, y2), (x1, y1)
                frac = abs(q) / q_max  # relative magnitude in [0, 1]
                ax.annotate(
                    "",
                    xy=head,
                    xytext=tail,
                    arrowprops=dict(
                        arrowstyle="-|>",
                        lw=1.0 + 3.0 * frac,  # line width in points
                        shrinkA=5,
                        shrinkB=5,
                        mutation_scale=12 + 30 * frac,  # arrow-head size
                        color="tab:blue",
                    ),
                )

            # Label with flow value near the pipe midpoint
            ax.text(
                0.5 * (x1 + x2), 0.5 * (y1 + y2),
                f"{q:.3f} m³/s",
                ha="center", va="center",
                fontsize=10, color="tab:blue"
            )

        fig.tight_layout()
        _ensure_parent_dir(save_path)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)

    return save_path