File size: 11,951 Bytes
0edffc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

def d8_flow(z, tol=1e-3):
    """Compute D8 (steepest-descent) flow directions for an elevation grid.

    Each cell drains towards the one of its eight neighbours with the largest
    downhill slope, where slope is the elevation drop divided by the
    neighbour distance (1 for edge neighbours, sqrt(2) for diagonals).
    Drops smaller than ``tol`` are not considered downhill.

    Cells that are NaN or ``<= 0`` are treated as ocean:

    * an ocean cell is always a sink;
    * a land cell with at least one ocean neighbour drains into the first
      such neighbour (in direction order) and is never marked as a sink.

    Direction index ``k`` maps to the offset ``(dy[k], dx[k])``:
    0=up, 1=down, 2=left, 3=right, 4=up-left, 5=up-right, 6=down-left,
    7=down-right.

    Args:
        z: 2D array-like of elevations. Non-floating input is promoted to
            float64 so that diagonal distances and differences are exact.
        tol: Minimum slope that counts as downhill.

    Returns:
        Tuple ``(rr, cc, is_sink, kmax)`` of ``(H, W)`` arrays:
        ``rr``/``cc`` are the row/column of each cell's receiver (clipped to
        the grid), ``is_sink`` flags cells with no outflow, and ``kmax`` is
        the chosen direction index. For sink cells ``kmax`` defaults to 0 and
        ``rr``/``cc`` carry no meaning.
    """
    z = np.asarray(z)
    if not np.issubdtype(z.dtype, np.floating):
        # With an integer dtype sqrt(2) would truncate to 1 in `dist`, and
        # unsigned differences would wrap around; work in floating point.
        z = z.astype(np.float64)
    H, W = z.shape
    dy = np.array([-1,  1,  0,  0, -1, -1,  1,  1], dtype=int)
    dx = np.array([ 0,  0, -1,  1, -1,  1, -1,  1], dtype=int)
    dist = np.array([1, 1, 1, 1, np.sqrt(2), np.sqrt(2), np.sqrt(2), np.sqrt(2)], dtype=z.dtype)

    # Edge padding makes an out-of-grid neighbour equal to the cell itself,
    # i.e. slope 0, which the tolerance test below rejects.
    zpad = np.pad(z, 1, mode='edge')
    nbrs = np.stack([zpad[1+dy[k]:1+dy[k]+H, 1+dx[k]:1+dx[k]+W] for k in range(8)], axis=0)
    slopes = (z[None] - nbrs) / dist[:, None, None]  # positive = downhill
    slopes[slopes < tol] = -np.inf

    # Ocean handling
    # - Centers that are NaN or <= 0 are sinks (ocean)
    # - Neighbors that are NaN or <= 0 act as ocean sinks: prefer routing into them
    center_ocean = np.isnan(z) | (z <= 0)
    neighbor_ocean = np.isnan(nbrs) | (nbrs <= 0)

    # prefer_ocean: ocean neighbours of land cells get +inf slope so they win
    # the argmax; ocean centres never route anywhere.
    prefer_ocean = slopes.copy()
    prefer_ocean[:, center_ocean] = -np.inf
    prefer_ocean[neighbor_ocean & (~center_ocean[None])] = np.inf

    # ignore_ocean: ocean neighbours are excluded, leaving only land-to-land
    # downhill routes; used to detect interior pits.
    ignore_ocean = slopes.copy()
    ignore_ocean[:, center_ocean] = -np.inf
    ignore_ocean[neighbor_ocean] = -np.inf

    kmax = np.argmax(prefer_ocean, axis=0)

    # A cell is a sink if it is ocean, or if it neither touches ocean nor has
    # any finite (i.e. accepted) downhill slope towards a land neighbour.
    has_land_route = np.isfinite(ignore_ocean.max(axis=0))
    has_ocean_neighbor = np.any(neighbor_ocean, axis=0)
    is_sink = center_ocean | ((~has_ocean_neighbor) & (~has_land_route))

    rr = np.clip(np.arange(H)[:, None] + dy[kmax], 0, H - 1)
    cc = np.clip(np.arange(W)[None, :] + dx[kmax], 0, W - 1)
    return rr, cc, is_sink, kmax

def flow_accumulation(z, rr, cc, is_sink):
    """Accumulate upstream cell counts along precomputed flow directions.

    Every land cell (finite and ``> 0``) starts with a contribution of 1;
    ocean cells (NaN or ``<= 0``) start at 0 and never receive flow. Land
    cells are visited from highest to lowest elevation and each non-sink cell
    adds its running total to its receiver ``(rr[i, j], cc[i, j])``.

    Args:
        z: 2D elevation array.
        rr, cc: Receiver row/column index arrays, same shape as ``z``.
        is_sink: Boolean array flagging cells that do not pass flow on.

    Returns:
        float32 array of accumulated contributions, same shape as ``z``.
    """
    n_rows, n_cols = z.shape
    ocean = np.isnan(z) | (z <= 0)
    land = ~ocean

    acc = np.zeros((n_rows, n_cols), dtype=np.float32)
    acc[land] = 1.0

    land_idx = np.flatnonzero(land)
    if not land_idx.size:
        return acc

    # Highest cells first, so a donor is final before it is passed downstream.
    descending = land_idx[np.argsort(z.ravel()[land_idx])[::-1]]
    rows, cols = descending // n_cols, descending % n_cols
    for src_r, src_c in zip(rows, cols):
        if is_sink[src_r, src_c]:
            continue
        dst_r, dst_c = rr[src_r, src_c], cc[src_r, src_c]
        if ocean[dst_r, dst_c]:
            continue
        acc[dst_r, dst_c] += acc[src_r, src_c]
    return acc

def plot_flow_indicator(z, max_pool_kernel=1):
    """Return a log-scaled flow-accumulation map for an elevation grid.

    Runs ``d8_flow`` and ``flow_accumulation`` on ``z``, zeroes ocean cells
    (NaN or ``<= 0``), optionally max-pools the result with non-overlapping
    ``max_pool_kernel`` x ``max_pool_kernel`` tiles (trailing rows/columns
    that do not fill a whole tile are dropped), and returns ``log1p`` of it.
    Despite the name, nothing is drawn here; the array is returned.
    """
    dem = np.asarray(z)
    recv_rows, recv_cols, sinks, _ = d8_flow(dem)
    acc = flow_accumulation(dem, recv_rows, recv_cols, sinks)

    # Ocean cells must not show up as contributing in the indicator.
    acc[np.isnan(dem) | (dem <= 0)] = 0.0

    k = max_pool_kernel
    if k > 1:
        # Non-overlapping max pool: stride equals the kernel size.
        rows_out = acc.shape[0] // k
        cols_out = acc.shape[1] // k
        tiles = acc[:rows_out * k, :cols_out * k].reshape(rows_out, k, cols_out, k)
        acc = tiles.max(axis=(1, 3))

    return np.log1p(acc)

def smooth_river_bumps(
    height,
    slope_thresh=50,     # below this, considered "flat"
    smooth_strength=0.3,    # fraction of smoothing applied
    iterations=3            # few iterations are enough
):
    """
    Slope-weighted Laplacian smoothing that flattens small bumps in
    low-gradient areas while leaving steep terrain largely untouched.

    Each iteration adds ``smooth_strength * w * laplace`` to the surface,
    where ``laplace`` is the 4-neighbour Laplacian computed over valid
    (non-NaN, in-raster) neighbours only and
    ``w = exp(-(slope / slope_thresh) ** 2)``. NaN cells are never updated
    and stay NaN. Neighbours outside the raster are treated as missing, so
    opposite edges of the grid do not influence each other.

    Args:
        height: 2D elevation array-like; it is not modified.
        slope_thresh: Gradient magnitude (elevation units per cell) at which
            the smoothing weight has dropped to ``exp(-1)``.
        smooth_strength: Step size applied to the weighted Laplacian.
        iterations: Number of smoothing passes.

    Returns:
        float32 array with the smoothed surface, same shape as ``height``.
    """
    h = np.array(height, dtype=np.float32)
    nan_mask = np.isnan(h)
    valid = ~nan_mask

    # Neighbour validity does not change between iterations. Padding with
    # False marks out-of-raster neighbours as missing (np.roll would instead
    # wrap around to the opposite edge).
    valid_pad = np.pad(valid, 1, mode='constant', constant_values=False)
    up_valid = valid_pad[:-2, 1:-1]
    dn_valid = valid_pad[2:, 1:-1]
    lf_valid = valid_pad[1:-1, :-2]
    rt_valid = valid_pad[1:-1, 2:]
    neighbor_cnt = (
        up_valid.astype(np.float32)
        + dn_valid.astype(np.float32)
        + lf_valid.astype(np.float32)
        + rt_valid.astype(np.float32)
    )

    for _ in range(iterations):
        # Gradients are taken on a surface where NaNs are replaced by 0, so
        # cells next to NaNs see a steep slope and get a reduced weight.
        h_safe = np.where(nan_mask, 0.0, h)
        grad_y, grad_x = np.gradient(h_safe)
        slope = np.sqrt(grad_x**2 + grad_y**2)

        # h_safe is 0 at NaN cells and the padding is 0, so missing
        # neighbours contribute nothing to the sum.
        h_pad = np.pad(h_safe, 1, mode='constant', constant_values=0.0)
        up = h_pad[:-2, 1:-1]
        dn = h_pad[2:, 1:-1]
        lf = h_pad[1:-1, :-2]
        rt = h_pad[1:-1, 2:]

        neighbor_sum = up + dn + lf + rt
        laplace = neighbor_sum - neighbor_cnt * h_safe
        laplace[nan_mask] = 0.0

        # Weight by (low slope) regions only; do not update NaN cells
        w = np.exp(- (slope / slope_thresh) ** 2)
        w[nan_mask] = 0.0

        # Apply selective smoothing, preserve NaNs
        h += smooth_strength * w * laplace
        h[nan_mask] = np.nan

    return h

import heapq

def fill_depressions_priority_flood(
    height: np.ndarray,
    epsilon: float = 1e-3,      # tiny gradient injected across flats
    max_raise: float | None = None,  # H_max: maximum allowed basin fill depth
    connectivity: int = 8,      # 4 or 8
    in_place: bool = False,
    nodata: float | None = None # treat NaNs (or this value) as barriers
) -> np.ndarray:
    """
    Priority-Flood selective depression fill.
    Fills pits only up to a maximum basin depth H_max ("max_raise").
    If the required fill depth exceeds H_max, the basin is left as a true
    depression (no further raising). Epsilon ensures drainage across flats.

    Args:
        height: 2D elevation array.
        epsilon: Small increment to ensure drainage across flats.
        connectivity: 4 or 8-neighbor graph.
        in_place: Modify input array in place if True.
        nodata: If provided, cells equal to this value are treated as invalid.
                NaNs are always treated as invalid.

    Returns:
        Filled elevation array (same shape).
    """
    h = height if in_place else height.copy()
    h = h.astype(np.float32, copy=False)
    # Preserve original heights; needed to track basin minima
    base = height.astype(np.float32, copy=False).copy()
    H, W = h.shape

    if nodata is None:
        ocean = np.isnan(h) | (h <= 0)
    else:
        ocean = np.isnan(h) | (h <= 0) | (h == nodata)
    invalid = ocean

    visited = np.zeros((H, W), dtype=bool)
    # Track the minimum original elevation encountered along the flood path
    # to each cell; used to measure basin fill depth relative to its minimum
    basin_min = np.full((H, W), np.inf, dtype=np.float32)
    heap: list[tuple[float, int, int]] = []

    if connectivity == 4:
        nbrs = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        nbrs = [(-1, 0), (1, 0), (0, -1), (0, 1),
                (-1, -1), (-1, 1), (1, -1), (1, 1)]

    # Seed with valid outer border cells
    for i in range(H):
        for j in (0, W - 1):
            if not invalid[i, j] and not visited[i, j]:
                heapq.heappush(heap, (h[i, j], i, j))
                visited[i, j] = True
                basin_min[i, j] = base[i, j]
    for j in range(W):
        for i in (0, H - 1):
            if not invalid[i, j] and not visited[i, j]:
                heapq.heappush(heap, (h[i, j], i, j))
                visited[i, j] = True
                basin_min[i, j] = base[i, j]

    # Also seed coast-adjacent valid cells (adjacent to ocean) as outlets
    if connectivity == 4:
        nbrs_seed = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        nbrs_seed = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]

    for r in range(H):
        for c in range(W):
            if invalid[r, c] or visited[r, c]:
                continue
            # If any neighbor is ocean, treat this as coastal outlet seed
            coastal = False
            for dr, dc in nbrs_seed:
                nr, nc = r + dr, c + dc
                if nr < 0 or nr >= H or nc < 0 or nc >= W:
                    continue
                if ocean[nr, nc]:
                    coastal = True
                    break
            if coastal:
                elev_seed = max(h[r, c], 0.0)
                heapq.heappush(heap, (elev_seed, r, c))
                visited[r, c] = True
                basin_min[r, c] = base[r, c]

    # Priority-Flood
    while heap:
        elev, r, c = heapq.heappop(heap)
        bm_cur = basin_min[r, c]
        for dr, dc in nbrs:
            nr, nc = r + dr, c + dc
            if nr < 0 or nr >= H or nc < 0 or nc >= W:
                continue
            if visited[nr, nc] or invalid[nr, nc]:
                continue

            ne = h[nr, nc]
            # Propagate basin minimum along the flood path
            bm_next = bm_cur if base[nr, nc] >= bm_cur else base[nr, nc]
            if ne <= elev:
                # Selective fill: stop raising if basin depth exceeds H_max
                if (max_raise is not None) and (elev - bm_cur >= max_raise):
                    heapq.heappush(heap, (ne, nr, nc))
                else:
                    new_e = elev + epsilon
                    # Ensure we never exceed the allowed basin depth
                    if max_raise is not None:
                        max_level = bm_cur + max_raise
                        if new_e > max_level:
                            new_e = max_level
                    if new_e > ne:
                        h[nr, nc] = new_e
                    heapq.heappush(heap, (h[nr, nc], nr, nc))
            else:
                heapq.heappush(heap, (ne, nr, nc))
            visited[nr, nc] = True
            basin_min[nr, nc] = bm_next

    return h

def local_baseline_temperature_torch(
    T: torch.Tensor,
    e: torch.Tensor,
    win: int = 3,
    beta_clip=(-0.012, 0.0),   # °C per meter
    fallback_beta=-0.0065,     # °C per meter
    eps=1e-6,
    fallback_threshold=0.3
):
    """
    Estimate local sea-level baseline temperature using a windowed regression.

    Within every ``win`` x ``win`` window the lapse rate ``beta`` is the
    land-weighted regression slope of temperature on elevation,
    ``cov(e, T) / var(e)``. Cells with ``e <= 0`` or NaN elevation are ocean
    and carry zero weight; their values (including NaNs) do not enter the
    window statistics. The baseline is ``T - beta * e`` evaluated on the raw
    values at the window anchor cell.

    Args:
        T, e: temperature [°C] and elevation [m] as 2D (H, W), 3D (B, H, W)
            or 4D (B, 1, H, W) tensors.
        win: window size (odd integer for a centred window; for even sizes
            the anchor is the cell at offset ``(win - 1) // 2``)
        beta_clip: allowed lapse-rate range (°C/m)
        fallback_beta: used where the local elevation variance is below 1 or
            the window's land fraction is below ``fallback_threshold``
        eps: small constant for stability
        fallback_threshold: minimum land fraction of a window for the local
            regression to be trusted

    Returns:
        T_sea: local baseline temperature map, shape
               (B, H-(win-1), W-(win-1)); B is 1 for 2D input. The channel
               dimension is squeezed when it has size 1.
        beta:  local lapse-rate map (same shape)
    """
    if T.ndim == 2:
        T = T.unsqueeze(0).unsqueeze(0)
        e = e.unsqueeze(0).unsqueeze(0)
    elif T.ndim == 3:
        T = T.unsqueeze(1)
        e = e.unsqueeze(1)

    # Land mask (1 = land, 0 = ocean); NaN elevations compare False -> ocean
    land = e > 0
    w = land.float()

    # Compute weighted means with valid convolution (no padding)
    def wavg(x):
        # Zero ocean cells explicitly: NaN * 0 is NaN and would poison every
        # window touching a NaN ocean cell. The multiplication by w is kept
        # for the same dtype promotion as a plain `x * w`.
        masked = torch.where(land, x, torch.zeros_like(x)) * w
        num = F.avg_pool2d(masked, win, stride=1, padding=0)
        den = F.avg_pool2d(w,      win, stride=1, padding=0)
        return num / (den + eps), den

    mu_T, sum_w = wavg(T)
    mu_e, _ = wavg(e)
    mu_e2, _ = wavg(e * e)
    mu_eT, _ = wavg(e * T)

    var_e  = mu_e2 - mu_e**2
    cov_eT = mu_eT - mu_e * mu_T

    # Local slope β (°C per meter)
    beta = cov_eT / (var_e + eps)

    # Flat or water-dominated windows → fallback β
    invalid = (var_e < 1.0) | (sum_w < fallback_threshold)  # <30% land by default
    # full_like keeps dtype and device aligned with beta.
    beta = torch.where(invalid, torch.full_like(beta, fallback_beta), beta)
    beta = torch.clamp(beta, beta_clip[0], beta_clip[1])

    # Sea-level baseline using raw T and e (no averaging); crop to the region
    # covered by the pooled maps. Explicit end indices are used because a
    # `pad:-pad` slice is empty when pad == 0 (win == 1) and one cell too
    # large for even window sizes.
    off = (win - 1) // 2
    out_h, out_w = beta.shape[-2], beta.shape[-1]
    T_c = T[:, :, off:off + out_h, off:off + out_w]
    e_c = e[:, :, off:off + out_h, off:off + out_w]

    T_sea = T_c - beta * e_c

    return T_sea.squeeze(1), beta.squeeze(1)