Mike0021's picture
init: terrain diffusion demo Space
0edffc2 verified
Raw
History Blame Contribute Delete
12 kB
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
def d8_flow(z, tol=1e-3):
    """Compute D8 (steepest-descent) flow targets on a 2D elevation grid.

    Cells whose elevation is NaN or <= 0 are treated as ocean. Ocean cells
    are always sinks; a land cell touching an ocean cell always drains into
    one of its ocean neighbors; any other land cell drains to its steepest
    downhill neighbor, or is a sink when no neighbor is lower by a slope of
    at least ``tol``.

    Args:
        z: 2D array-like of elevations.
        tol: Minimum slope (drop / distance) for a neighbor to count as
            downhill; smaller slopes are discarded.

    Returns:
        rr, cc: (H, W) integer arrays holding the row / column of each
            cell's chosen neighbor, clipped to the grid.
        is_sink: (H, W) boolean array, True where the cell does not drain.
        kmax: (H, W) index (0..7) of the chosen direction in the offset
            tables below. Cells with no candidate at all (ocean cells and
            land sinks) get index 0, so rr/cc/kmax are only meaningful where
            ``is_sink`` is False.
    """
    z = np.asarray(z)
    H, W = z.shape

    # Neighbor offsets: the four cardinal directions first, then diagonals.
    dy = np.array([-1, 1, 0, 0, -1, -1, 1, 1], dtype=int)
    dx = np.array([0, 0, -1, 1, -1, 1, -1, 1], dtype=int)

    # Step lengths must be floating point: building them with an integer
    # dtype would truncate sqrt(2) to 1 and overstate diagonal slopes.
    dist_dtype = z.dtype if np.issubdtype(z.dtype, np.floating) else np.float64
    diag = np.sqrt(2)
    dist = np.array([1, 1, 1, 1, diag, diag, diag, diag], dtype=dist_dtype)

    # Edge padding makes out-of-grid neighbors equal to the border cell
    # itself, i.e. slope 0, which is then rejected by the tolerance test.
    zpad = np.pad(z, 1, mode='edge')
    nbrs = np.stack(
        [zpad[1 + dy[k]:1 + dy[k] + H, 1 + dx[k]:1 + dx[k] + W] for k in range(8)],
        axis=0,
    )
    slopes = (z[None] - nbrs) / dist[:, None, None]  # positive = downhill
    slopes[slopes < tol] = -np.inf

    # Ocean handling: NaN or <= 0, for the center cell and for its neighbors.
    center_ocean = np.isnan(z) | (z <= 0)
    neighbor_ocean = np.isnan(nbrs) | (nbrs <= 0)

    # prefer_ocean: ocean neighbors of land cells get +inf slope so that
    # argmax routes into them; ocean centers have no outgoing direction.
    prefer_ocean = slopes.copy()
    prefer_ocean[:, center_ocean] = -np.inf
    prefer_ocean[neighbor_ocean & (~center_ocean[None])] = np.inf

    # ignore_ocean: ocean neighbors are excluded entirely; used only to
    # decide whether a land cell has any downhill land neighbor.
    ignore_ocean = slopes.copy()
    ignore_ocean[:, center_ocean] = -np.inf
    ignore_ocean[neighbor_ocean] = -np.inf

    kmax = np.argmax(prefer_ocean, axis=0)

    # A land cell is a sink only if it neither touches ocean nor has a
    # downhill land neighbor.
    max_slope_ignore = ignore_ocean.max(axis=0)
    has_ocean_neighbor = np.any(neighbor_ocean, axis=0)
    is_sink = center_ocean | ((~has_ocean_neighbor) & (~np.isfinite(max_slope_ignore)))

    rr = np.clip(np.arange(H)[:, None] + dy[kmax], 0, H - 1)
    cc = np.clip(np.arange(W)[None, :] + dx[kmax], 0, W - 1)
    return rr, cc, is_sink, kmax
def flow_accumulation(z, rr, cc, is_sink):
    """Accumulate upstream cell counts along precomputed flow targets.

    Every land cell (not NaN and > 0) starts with a contribution of 1.
    Land cells are visited from highest to lowest elevation and each
    non-sink cell adds its running total to its target cell, unless that
    target is ocean.

    Args:
        z: 2D elevation array; NaN or <= 0 marks ocean.
        rr, cc: per-cell row / column index of the downstream target.
        is_sink: boolean array, True where a cell passes nothing on.

    Returns:
        float32 array of accumulated counts; ocean cells stay at 0.
    """
    n_rows, n_cols = z.shape
    ocean = np.isnan(z) | (z <= 0)

    acc = np.zeros((n_rows, n_cols), dtype=np.float32)
    acc[~ocean] = 1.0

    land = np.flatnonzero(~ocean)
    if not land.size:
        return acc

    # Highest cells first, so a cell's total is complete before it is
    # handed to its downstream neighbor.
    descending = land[np.argsort(z.ravel()[land])[::-1]]
    rows, cols = np.unravel_index(descending, (n_rows, n_cols))
    for i, j in zip(rows, cols):
        if is_sink[i, j]:
            continue
        ti, tj = rr[i, j], cc[i, j]
        if ocean[ti, tj]:
            continue
        acc[ti, tj] += acc[i, j]
    return acc
def plot_flow_indicator(z, max_pool_kernel=1):
    """Return a log-scaled flow-accumulation map for an elevation grid.

    Runs ``d8_flow`` and ``flow_accumulation`` on ``z``, zeroes ocean cells
    (NaN or <= 0), optionally max-pools the result with non-overlapping
    ``max_pool_kernel`` x ``max_pool_kernel`` blocks (trailing rows/columns
    that do not fill a block are dropped), and returns ``log1p`` of it.

    Despite the name, nothing is drawn here; the array is returned.
    """
    z = np.asarray(z)
    rr, cc, is_sink, _ = d8_flow(z)
    acc = flow_accumulation(z, rr, cc, is_sink)

    # Ocean must not contribute to the indicator.
    acc[np.isnan(z) | (z <= 0)] = 0.0

    k = max_pool_kernel
    if k > 1:
        # Downsampling max pool: stride equals the kernel size.
        out_h = acc.shape[0] // k
        out_w = acc.shape[1] // k
        blocks = acc[:out_h * k, :out_w * k].reshape(out_h, k, out_w, k)
        acc = blocks.max(axis=(1, 3))

    return np.log1p(acc)
def smooth_river_bumps(
    height,
    slope_thresh=50,      # below this, considered "flat"
    smooth_strength=0.3,  # fraction of smoothing applied
    iterations=3,         # few iterations are enough
):
    """Remove small bumps in flat areas while preserving steep slopes.

    Each iteration applies a 4-neighbor Laplacian smoothing step whose
    strength is weighted by ``exp(-(slope / slope_thresh) ** 2)``, so
    low-slope cells are smoothed and steep cells are left almost unchanged.
    NaN cells are never updated and are skipped as neighbors. Neighbors
    outside the grid are skipped as well (the grid is not periodic).

    Args:
        height: 2D array of heights; may contain NaN.
        slope_thresh: Gradient magnitude scale below which a cell counts as
            flat.
        smooth_strength: Multiplier applied to the weighted Laplacian.
        iterations: Number of smoothing passes.

    Returns:
        A new float32 array with the same shape; NaNs stay NaN and the
        input is not modified.
    """
    h = height.copy().astype(np.float32)
    nan_mask = np.isnan(h)
    valid = ~nan_mask

    def _neighbor_mask(shift, axis):
        # Validity of the neighbor brought in by np.roll(..., shift, axis).
        # np.roll wraps around, so the row/column that received data from
        # the opposite edge is marked invalid.
        mask = np.roll(valid, shift, axis)
        wrapped = [slice(None), slice(None)]
        wrapped[axis] = 0 if shift > 0 else -1
        mask[tuple(wrapped)] = False
        return mask

    # (shift, axis) pairs for the up, down, left and right neighbors. The
    # masks depend only on the NaN pattern, which never changes.
    shifts = ((1, 0), (-1, 0), (1, 1), (-1, 1))
    masks = [_neighbor_mask(shift, axis) for shift, axis in shifts]
    neighbor_cnt = (
        masks[0].astype(np.float32)
        + masks[1].astype(np.float32)
        + masks[2].astype(np.float32)
        + masks[3].astype(np.float32)
    )

    for _ in range(iterations):
        # NaNs are replaced by 0 for the array ops. Where heights are far
        # from 0 this produces large slopes next to NaN cells, which in turn
        # gives those cells a small smoothing weight.
        h_safe = np.where(nan_mask, 0.0, h)
        grad_y, grad_x = np.gradient(h_safe)
        slope = np.sqrt(grad_x**2 + grad_y**2)

        # Laplacian over valid in-grid neighbors only.
        neighbor_sum = 0.0
        for mask, (shift, axis) in zip(masks, shifts):
            neighbor_sum = neighbor_sum + np.where(mask, np.roll(h_safe, shift, axis), 0.0)
        laplace = neighbor_sum - neighbor_cnt * h_safe
        laplace[nan_mask] = 0.0

        # Weight favors low-slope regions; NaN cells get no update.
        w = np.exp(-(slope / slope_thresh) ** 2)
        w[nan_mask] = 0.0

        h += smooth_strength * w * laplace
        h[nan_mask] = np.nan
    return h
import heapq
def fill_depressions_priority_flood(
height: np.ndarray,
epsilon: float = 1e-3, # tiny gradient injected across flats
max_raise: float | None = None, # H_max: maximum allowed basin fill depth
connectivity: int = 8, # 4 or 8
in_place: bool = False,
nodata: float | None = None # treat NaNs (or this value) as barriers
) -> np.ndarray:
"""
Priority-Flood selective depression fill.
Fills pits only up to a maximum basin depth H_max ("max_raise").
If the required fill depth exceeds H_max, the basin is left as a true
depression (no further raising). Epsilon ensures drainage across flats.
Args:
height: 2D elevation array.
epsilon: Small increment to ensure drainage across flats.
connectivity: 4 or 8-neighbor graph.
in_place: Modify input array in place if True.
nodata: If provided, cells equal to this value are treated as invalid.
NaNs are always treated as invalid.
Returns:
Filled elevation array (same shape).
"""
h = height if in_place else height.copy()
h = h.astype(np.float32, copy=False)
# Preserve original heights; needed to track basin minima
base = height.astype(np.float32, copy=False).copy()
H, W = h.shape
if nodata is None:
ocean = np.isnan(h) | (h <= 0)
else:
ocean = np.isnan(h) | (h <= 0) | (h == nodata)
invalid = ocean
visited = np.zeros((H, W), dtype=bool)
# Track the minimum original elevation encountered along the flood path
# to each cell; used to measure basin fill depth relative to its minimum
basin_min = np.full((H, W), np.inf, dtype=np.float32)
heap: list[tuple[float, int, int]] = []
if connectivity == 4:
nbrs = [(-1, 0), (1, 0), (0, -1), (0, 1)]
else:
nbrs = [(-1, 0), (1, 0), (0, -1), (0, 1),
(-1, -1), (-1, 1), (1, -1), (1, 1)]
# Seed with valid outer border cells
for i in range(H):
for j in (0, W - 1):
if not invalid[i, j] and not visited[i, j]:
heapq.heappush(heap, (h[i, j], i, j))
visited[i, j] = True
basin_min[i, j] = base[i, j]
for j in range(W):
for i in (0, H - 1):
if not invalid[i, j] and not visited[i, j]:
heapq.heappush(heap, (h[i, j], i, j))
visited[i, j] = True
basin_min[i, j] = base[i, j]
# Also seed coast-adjacent valid cells (adjacent to ocean) as outlets
if connectivity == 4:
nbrs_seed = [(-1, 0), (1, 0), (0, -1), (0, 1)]
else:
nbrs_seed = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
for r in range(H):
for c in range(W):
if invalid[r, c] or visited[r, c]:
continue
# If any neighbor is ocean, treat this as coastal outlet seed
coastal = False
for dr, dc in nbrs_seed:
nr, nc = r + dr, c + dc
if nr < 0 or nr >= H or nc < 0 or nc >= W:
continue
if ocean[nr, nc]:
coastal = True
break
if coastal:
elev_seed = max(h[r, c], 0.0)
heapq.heappush(heap, (elev_seed, r, c))
visited[r, c] = True
basin_min[r, c] = base[r, c]
# Priority-Flood
while heap:
elev, r, c = heapq.heappop(heap)
bm_cur = basin_min[r, c]
for dr, dc in nbrs:
nr, nc = r + dr, c + dc
if nr < 0 or nr >= H or nc < 0 or nc >= W:
continue
if visited[nr, nc] or invalid[nr, nc]:
continue
ne = h[nr, nc]
# Propagate basin minimum along the flood path
bm_next = bm_cur if base[nr, nc] >= bm_cur else base[nr, nc]
if ne <= elev:
# Selective fill: stop raising if basin depth exceeds H_max
if (max_raise is not None) and (elev - bm_cur >= max_raise):
heapq.heappush(heap, (ne, nr, nc))
else:
new_e = elev + epsilon
# Ensure we never exceed the allowed basin depth
if max_raise is not None:
max_level = bm_cur + max_raise
if new_e > max_level:
new_e = max_level
if new_e > ne:
h[nr, nc] = new_e
heapq.heappush(heap, (h[nr, nc], nr, nc))
else:
heapq.heappush(heap, (ne, nr, nc))
visited[nr, nc] = True
basin_min[nr, nc] = bm_next
return h
def local_baseline_temperature_torch(
    T: torch.Tensor,
    e: torch.Tensor,
    win: int = 3,
    beta_clip=(-0.012, 0.0),  # °C per meter
    fallback_beta=-0.0065,    # °C per meter
    eps=1e-6,
    fallback_threshold=0.3,
):
    """
    Estimate local sea-level baseline temperature using a windowed regression.

    Within every ``win`` x ``win`` window, temperature is regressed on
    elevation over land cells (``e > 0``) to obtain a local lapse rate
    ``beta``; the baseline is then ``T - beta * e`` at the window's
    reference cell.

    Args:
        T, e: temperature [°C] and elevation [m] as 2D (H, W), 3D (B, H, W)
            or 4D (B, 1, H, W) tensors of matching shape.
        win: window size (>= 1). Windows are evaluated without padding, so
            the outputs shrink by ``win - 1`` in each spatial dimension.
            For odd ``win`` the reference cell is the window center.
        beta_clip: allowed lapse-rate range (°C/m); beta is clamped to it.
        fallback_beta: lapse rate used where the windowed elevation variance
            is < 1 or the land fraction is below ``fallback_threshold``.
        eps: small constant for numerical stability in the divisions.
        fallback_threshold: minimum land fraction of a window (0..1) for the
            local regression to be trusted.

    Returns:
        T_sea: local baseline temperature map, shape
            (B, H - (win - 1), W - (win - 1)) (the channel dim is squeezed).
        beta: local lapse-rate map (same shape).

    Raises:
        ValueError: if ``win`` is smaller than 1.
    """
    if win < 1:
        raise ValueError(f"win must be >= 1, got {win}")

    # Normalize to (B, 1, H, W) as expected by avg_pool2d.
    if T.ndim == 2:
        T = T.unsqueeze(0).unsqueeze(0)
        e = e.unsqueeze(0).unsqueeze(0)
    elif T.ndim == 3:
        T = T.unsqueeze(1)
        e = e.unsqueeze(1)

    # Land mask (1 = land, 0 = ocean)
    w = (e > 0).float()

    # Fraction of land cells in each window (valid pooling, no padding).
    land_frac = F.avg_pool2d(w, win, stride=1, padding=0)

    def wavg(x):
        # Land-weighted window mean of x.
        num = F.avg_pool2d(x * w, win, stride=1, padding=0)
        return num / (land_frac + eps)

    mu_T = wavg(T)
    mu_e = wavg(e)
    mu_e2 = wavg(e * e)
    mu_eT = wavg(e * T)

    var_e = mu_e2 - mu_e**2
    cov_eT = mu_eT - mu_e * mu_T

    # Local slope β (°C per meter)
    beta = cov_eT / (var_e + eps)

    # Flat or water-dominated windows → fallback β
    invalid = (var_e < 1.0) | (land_frac < fallback_threshold)
    beta = beta.masked_fill(invalid, fallback_beta)
    beta = torch.clamp(beta, beta_clip[0], beta_clip[1])

    # Sea-level baseline from raw T and e (no averaging). Crop by the
    # pooled output size rather than with a negative stop index, which
    # would produce an empty slice when the offset is 0 (win == 1 or 2).
    off = (win - 1) // 2
    out_h, out_w = beta.shape[-2], beta.shape[-1]
    T_c = T[:, :, off:off + out_h, off:off + out_w]
    e_c = e[:, :, off:off + out_h, off:off + out_w]

    T_sea = T_c - beta * e_c
    return T_sea.squeeze(1), beta.squeeze(1)