Flood Water Depth Prediction Model (U-Net)

A U-Net regression model that predicts flood water depth (meters) given terrain and climate inputs.

Architecture

  • Model: U-Net with ResNet-34 encoder (ImageNet pretrained)
  • Task: Pixel-wise regression of continuous water depth (≥ 0 meters)
  • Input: 9-channel raster (128×128 patches)
  • Output: Single-channel depth map (meters)

Input Channels

| # | Channel | Unit |
|---|---------|------|
| 0 | DEM (elevation) | meters |
| 1 | Slope | degrees |
| 2 | HAND (Height Above Nearest Drainage) | meters |
| 3 | TWI (Topographic Wetness Index) | unitless |
| 4 | Manning's roughness | unitless |
| 5 | Soil permeability | mm/hr |
| 6 | Rainfall intensity | mm/hr |
| 7 | Storm duration | hours |
| 8 | Return period | years |

Performance (Test Set)

| Metric | Value |
|--------|-------|
| RMSE | 1.0404 m |
| MAE | 0.4178 m |
| R² | 0.7916 |
| Extent F1 | 0.7549 |
| CSI@0.5m | 0.6450 |
| CSI@1.0m | 0.6594 |
| CSI@2.0m | 0.6893 |
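
The card does not state how these metrics were computed. A minimal sketch, assuming CSI@t is the critical success index obtained by thresholding both predicted and reference depth at t meters (hits / (hits + misses + false alarms)), and that RMSE and MAE are taken over all pixels. The function names are illustrative and are not taken from the training scripts:

```python
import numpy as np

def csi(pred, ref, threshold):
    """Critical success index for depth > threshold (assumed definition)."""
    p = pred > threshold
    r = ref > threshold
    hits = np.sum(p & r)
    misses = np.sum(~p & r)
    false_alarms = np.sum(p & ~r)
    denom = hits + misses + false_alarms
    # Undefined when neither map exceeds the threshold anywhere
    return float(hits / denom) if denom > 0 else float("nan")

def rmse(pred, ref):
    """Root-mean-square error over all pixels, in meters."""
    return float(np.sqrt(np.mean((pred - ref) ** 2)))

def mae(pred, ref):
    """Mean absolute error over all pixels, in meters."""
    return float(np.mean(np.abs(pred - ref)))
```

Calling `csi(depth_map, reference_depth, 0.5)` on a predicted and a reference depth grid of the same shape would give a value comparable in kind to the CSI@0.5m row, provided the assumed definition matches the one used for evaluation.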

How to Use

1. Install dependencies

pip install torch segmentation-models-pytorch huggingface_hub rasterio numpy scipy

2. Load the model

import torch
import numpy as np
import json
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# Download model weights and normalization stats
weights_path = hf_hub_download("peyterho/flood-depth-unet", "model.pt")
stats_path   = hf_hub_download("peyterho/flood-depth-unet", "stats.json")

# Build and load model
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=9,
    classes=1,
    activation=None,
)
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Load normalization constants
with open(stats_path) as f:
    stats = json.load(f)

means = np.array(stats["means"], dtype=np.float32).reshape(9, 1, 1)
stds  = np.array(stats["stds"],  dtype=np.float32).reshape(9, 1, 1)

3. Prepare your input

The model expects a 9-channel raster of shape (9, 128, 128). Six channels come from your terrain data; the other three are scalars you set to define the climate scenario:

# === YOUR TERRAIN DATA (channels 0–5) ===
# Each is a 2D numpy array of shape (128, 128) at ~30 m resolution

dem       = ...  # elevation in meters (from SRTM, Copernicus DEM, or LiDAR)
slope     = ...  # terrain slope in degrees
hand      = ...  # Height Above Nearest Drainage in meters
twi       = ...  # Topographic Wetness Index
manning   = ...  # Manning's roughness (0.03=smooth, 0.06=vegetated, 0.15=dense forest)
soil_perm = ...  # soil permeability in mm/hr (5=clay, 15=loam, 50=sand)

# === YOUR CLIMATE SCENARIO (channels 6–8) ===
# These are scalar values broadcast to the full 128×128 grid

rainfall_intensity = 80.0   # mm/hr: how hard it's raining
storm_duration     = 12.0   # hours: how long the storm lasts
return_period      = 100.0  # years: rarity of the event (2=common, 500=extreme)

H, W = 128, 128
rain_ch = np.full((H, W), rainfall_intensity, dtype=np.float32)
dur_ch  = np.full((H, W), storm_duration,     dtype=np.float32)
rp_ch   = np.full((H, W), return_period,      dtype=np.float32)

# === Stack and normalize ===
raw = np.stack([dem, slope, hand, twi, manning, soil_perm,
                rain_ch, dur_ch, rp_ch])          # shape: (9, 128, 128)
normalized = (raw - means) / (stds + 1e-8)
x = torch.from_numpy(normalized).unsqueeze(0).float()  # (1, 9, 128, 128)

4. Predict flood depth

with torch.no_grad():
    depth_map = torch.relu(model(x)).squeeze().numpy()  # (128, 128), values in meters

# depth_map[i, j] = predicted water depth at that pixel
# 0.0 = dry, 1.5 = 1.5 meters of water, etc.

5. Compare climate scenarios

Hold the terrain fixed and vary the storm to see how depth changes:

scenarios = {
    "Moderate storm":     ( 40.0,  6.0,   10),
    "Heavy storm":        ( 80.0, 12.0,   50),
    "Extreme (1-in-500)": (160.0, 24.0,  500),
}

for name, (rain, dur, rp) in scenarios.items():
    raw[6, :, :] = rain
    raw[7, :, :] = dur
    raw[8, :, :] = rp
    normed = (raw - means) / (stds + 1e-8)
    x = torch.from_numpy(normed).unsqueeze(0).float()

    with torch.no_grad():
        depth = torch.relu(model(x)).squeeze().numpy()

    print(f"{name}: max depth = {depth.max():.2f}m, "
          f"flooded area = {(depth > 0.01).mean()*100:.1f}%")

Where to Get the Terrain Inputs

| Channel | Free source | How to compute |
|---------|-------------|----------------|
| DEM | Copernicus DEM 30 m or SRTM | Download GeoTIFF for your area |
| Slope | Derived from DEM | np.degrees(np.arctan(np.sqrt(dx**2 + dy**2))) where dy, dx = np.gradient(dem, 30.0) |
| HAND | Derived from DEM | Use pysheds or whitebox (see below) |
| TWI | Derived from DEM | np.log(contributing_area / np.tan(slope_radians)) |
| Manning's n | ESA WorldCover land use → lookup table | Forest = 0.12, grass = 0.05, urban = 0.02, water = 0.03 |
| Soil perm. | SoilGrids or HWSD | Clay = 5, loam = 15, sand = 50 mm/hr |
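
The Manning's n row can be turned into a raster with a class-code lookup. A minimal sketch, assuming the ESA WorldCover class codes 10 (tree cover), 30 (grassland), 50 (built-up), and 80 (permanent water bodies) correspond to the forest, grass, urban, and water values in the table. The `DEFAULT_N` fallback for every other class is an illustrative assumption, not a value from this model card, and `manning_from_landcover` is a hypothetical helper:

```python
import numpy as np

# WorldCover class code -> Manning's n (values from the table above)
WORLDCOVER_TO_MANNING = {
    10: 0.12,  # tree cover -> "forest"
    30: 0.05,  # grassland -> "grass"
    50: 0.02,  # built-up -> "urban"
    80: 0.03,  # permanent water bodies -> "water"
}
DEFAULT_N = 0.05  # assumption: fallback for classes the table does not list

def manning_from_landcover(landcover):
    """Map an integer land-cover raster to a float32 Manning's n raster."""
    manning = np.full(landcover.shape, DEFAULT_N, dtype=np.float32)
    for code, n in WORLDCOVER_TO_MANNING.items():
        manning[landcover == code] = n
    return manning
```

The land-cover raster would need to be resampled to the DEM grid before the lookup so that all nine channels line up pixel for pixel.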

Computing slope from a DEM

dy, dx = np.gradient(dem, 30.0)          # 30 m pixel spacing
slope = np.degrees(np.arctan(np.sqrt(dx**2 + dy**2)))

Computing HAND from a DEM with pysheds

pip install pysheds

from pysheds.grid import Grid

grid = Grid.from_raster("dem.tif")
dem  = grid.read_raster("dem.tif")

pit_filled = grid.fill_pits(dem)
flooded    = grid.fill_depressions(pit_filled)
inflated   = grid.resolve_flats(flooded)
fdir       = grid.flowdir(inflated)
acc        = grid.accumulation(fdir)
hand       = grid.compute_hand(fdir, dem, acc > 500)  # threshold for channel cells
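
Computing TWI from flow accumulation and slope

The table above gives TWI as ln(contributing area / tan(slope)). A minimal sketch, assuming `acc` is a flow-accumulation array in cell counts (such as the one from the HAND example), square 30 m cells, and the common convention of using specific catchment area (contributing area per unit contour width). The `min_slope_deg` floor is an assumption added to avoid division by zero on flat cells, and the card does not say whether the training data used this exact formulation:

```python
import numpy as np

def compute_twi(acc, slope_deg, cell_size=30.0, min_slope_deg=0.1):
    """Topographic Wetness Index: ln(a / tan(beta))."""
    acc = np.asarray(acc, dtype=np.float64)
    # Specific catchment area = (cells * cell_area) / contour width (one cell).
    # Clamp to at least one cell so the logarithm is defined whether or not
    # the accumulation counts the cell itself.
    sca = np.maximum(acc, 1.0) * cell_size
    # Floor the slope so tan(beta) stays positive on flat terrain
    beta = np.radians(np.maximum(slope_deg, min_slope_deg))
    return np.log(sca / np.tan(beta)).astype(np.float32)
```

With the arrays from the earlier examples this would be called as `twi = compute_twi(np.asarray(acc), slope)`.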

Handling Areas Larger Than 128×128

A DEM for a real study area is usually much larger than a single patch. Use sliding-window inference with overlapping tiles, averaging predictions where tiles overlap, to reduce tile-edge artefacts:

def predict_large_area(model, raw_input, means, stds,
                       tile_size=128, overlap=32):
    """Predict flood depth over a large raster using tiled inference."""
    C, H, W = raw_input.shape
    depth_map = np.zeros((H, W), dtype=np.float32)
    count_map = np.zeros((H, W), dtype=np.float32)
    stride = tile_size - overlap

    for y in range(0, H, stride):
        for x in range(0, W, stride):
            y1, y2 = y, min(y + tile_size, H)
            x1, x2 = x, min(x + tile_size, W)
            patch = raw_input[:, y1:y2, x1:x2]

            # Pad partial edge tiles by repeating border values; zero-filling
            # raw inputs would put a false 0 m elevation step at the tile edge
            ph, pw = patch.shape[1], patch.shape[2]
            if ph < tile_size or pw < tile_size:
                patch = np.pad(
                    patch,
                    ((0, 0), (0, tile_size - ph), (0, tile_size - pw)),
                    mode="edge",
                )

            normed = (patch - means) / (stds + 1e-8)
            t = torch.from_numpy(normed).unsqueeze(0).float()

            with torch.no_grad():
                out = torch.relu(model(t)).squeeze().numpy()

            depth_map[y1:y2, x1:x2] += out[:ph, :pw]
            count_map[y1:y2, x1:x2] += 1.0

    return depth_map / np.maximum(count_map, 1.0)
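
The loop above pads partial tiles at the right and bottom edges. One possible refinement, which is not part of the original function, is to shift the last tile inward so that every tile is full-size and no padding is needed for rasters of at least `tile_size` pixels in each dimension. A minimal sketch; `tile_starts` is a hypothetical helper:

```python
def tile_starts(length, tile_size=128, overlap=32):
    """Start offsets of full-size tiles covering [0, length)."""
    if length <= tile_size:
        return [0]  # a single tile; the caller still pads if length < tile_size
    stride = tile_size - overlap
    starts = list(range(0, length - tile_size, stride))
    # Anchor the final tile to the far edge so it never runs past the raster
    starts.append(length - tile_size)
    return starts
```

In `predict_large_area`, the two loops would then iterate over `tile_starts(H)` and `tile_starts(W)` instead of `range(0, H, stride)` and `range(0, W, stride)`. The last two tiles may overlap by more than `overlap` pixels, which the existing count-based averaging already handles.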

Saving Results as a GeoTIFF

import rasterio

# Copy CRS and transform from your input DEM; depth_map must have the
# same height and width as this raster
with rasterio.open("dem.tif") as src:
    meta = src.meta.copy()

meta.update(count=1, dtype="float32", nodata=-9999)

with rasterio.open("flood_depth.tif", "w", **meta) as dst:
    dst.write(depth_map, 1)

# → open flood_depth.tif in QGIS, ArcGIS, or Google Earth Engine

Climate Conditioning

Vary input channels 6–8 to model different climate scenarios:

  • Rainfall intensity: 10–160 mm/hr (higher = heavier storms)
  • Duration: 1–24 hours (longer = more accumulation)
  • Return period: 2–500 years (higher = rarer, more severe events)
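
Because the climate channels are plain scalars, it is easy to pass values outside the ranges listed above. A minimal sketch of a range check, assuming those ranges are the limits the model should be used within (the card does not say whether they are the exact training ranges); `check_scenario` is a hypothetical helper:

```python
# Ranges copied from the list above: (min, max)
SCENARIO_RANGES = {
    "rainfall_intensity": (10.0, 160.0),  # mm/hr
    "storm_duration": (1.0, 24.0),        # hours
    "return_period": (2.0, 500.0),        # years
}

def check_scenario(rainfall_intensity, storm_duration, return_period):
    """Return a list of warnings for values outside the listed ranges."""
    values = {
        "rainfall_intensity": rainfall_intensity,
        "storm_duration": storm_duration,
        "return_period": return_period,
    }
    warnings = []
    for name, value in values.items():
        lo, hi = SCENARIO_RANGES[name]
        if not lo <= value <= hi:
            warnings.append(f"{name}={value} is outside [{lo}, {hi}]")
    return warnings
```

An empty list means all three values fall inside the listed ranges; anything else is worth reviewing before trusting the predicted depths.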

Training

  • Data: Physics-informed synthetic (DEM + Manning's equation + HAND)
  • Loss: Weighted MSE + L1 + Gradient loss
  • Optimizer: AdamW (lr=0.001)
  • Scheduler: Cosine Annealing
  • Epochs: 50
  • Parameters: ~24.5M
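
For reference, cosine annealing lowers the learning rate from its initial value along a half cosine. A minimal sketch of the per-epoch schedule, assuming the standard closed form with one step per epoch over the 50 epochs. The card does not state a minimum learning rate or any warm-up, so `eta_min=0.0` is an assumption:

```python
import math

def cosine_lr(epoch, base_lr=1e-3, total_epochs=50, eta_min=0.0):
    """Learning rate at a given epoch under cosine annealing."""
    cos_term = 1.0 + math.cos(math.pi * epoch / total_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term

# Print the rate at the start, midpoint, and end of training
for epoch in (0, 25, 50):
    print(epoch, f"{cosine_lr(epoch):.6f}")
```

Under these assumptions the rate starts at 0.001, is halved at the midpoint, and reaches zero at the final epoch.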

Limitations

  • Trained on synthetic data; fine-tune on regional hydraulic model outputs (HEC-RAS, LISFLOOD-FP) for production use
  • Simplified physics (no infrastructure, backwater effects, or 2D flow routing)
  • Best used as a screening tool, pre-training backbone, or educational tool

TODO / Roadmap

Make the model production-ready

  • Fine-tune on real hydraulic simulation data: run LISFLOOD-FP or HEC-RAS for a target locale and use those depth grids as training labels instead of the synthetic physics approximations
  • Add SAR imagery channels (Sentinel-1 VV/VH): lets the model learn from actual observed flood events, not just terrain features
  • Swap backbone to a geospatial foundation model: replace ResNet-34 with Prithvi-EO-2.0 via TerraTorch for much better generalization, especially when labeled data is scarce

Improve what we have

  • Train on larger patches (256×256 or 512×512) with GPU: captures wider drainage patterns and river context
  • Generate more diverse synthetic terrain: add coastal, urban, and mountainous terrain types to the data generator
  • Build a Gradio demo Space: let users upload a DEM and get a flood depth map in the browser
  • Add temporal dimension: extend the model to predict depth evolution over time (multi-step forecast), not just peak depth

Validate on a real locale

  • Download Copernicus DEM for a target area and compute slope, HAND, TWI
  • Run inference on real terrain to see how the model performs outside synthetic data
  • Compare predictions against known flood extents (e.g., Sen1Floods11 or Copernicus EMS rapid activations)

Reproduce or extend training

All scripts are in scripts/.
