Flood Water Depth Prediction Model (U-Net)
A U-Net regression model that predicts flood water depth (meters) given terrain and climate inputs.
Architecture
- Model: U-Net with ResNet-34 encoder (ImageNet pretrained)
- Task: Pixel-wise regression of continuous water depth (≥ 0 meters)
- Input: 9-channel raster (128×128 patches)
- Output: Single-channel depth map (meters)
Input Channels
| # | Channel | Unit |
|---|---|---|
| 0 | DEM (elevation) | meters |
| 1 | Slope | degrees |
| 2 | HAND (Height Above Nearest Drainage) | meters |
| 3 | TWI (Topographic Wetness Index) | unitless |
| 4 | Manning's roughness | unitless |
| 5 | Soil permeability | mm/hr |
| 6 | Rainfall intensity | mm/hr |
| 7 | Storm duration | hours |
| 8 | Return period | years |
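The model only sees channel indices, so the stack must follow this table exactly. Below is a minimal sketch of a helper that enforces the order. It assumes terrain layers are equally shaped 2D NumPy arrays and climate values are scalars. `CHANNELS` and `stack_channels` are hypothetical names, not part of the released code.

```python
import numpy as np

# Channel order from the table above; position i is model input channel i.
CHANNELS = [
    "dem", "slope", "hand", "twi", "manning", "soil_perm",
    "rainfall_intensity", "storm_duration", "return_period",
]

def stack_channels(layers):
    """Stack named layers into a (9, H, W) float32 array in model channel order.

    Terrain layers are 2D arrays; climate layers may be scalars, which are
    broadcast to a constant grid with the same shape as the DEM.
    """
    missing = [name for name in CHANNELS if name not in layers]
    if missing:
        raise KeyError(f"missing channels: {missing}")
    ref_shape = np.shape(layers["dem"])
    stacked = []
    for name in CHANNELS:
        arr = np.asarray(layers[name], dtype=np.float32)
        if arr.ndim == 0:  # scalar climate value -> constant grid
            arr = np.full(ref_shape, float(arr), dtype=np.float32)
        if arr.shape != ref_shape:
            raise ValueError(f"{name} has shape {arr.shape}, expected {ref_shape}")
        stacked.append(arr)
    return np.stack(stacked)
```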
Performance (Test Set)
| Metric | Value |
|---|---|
| RMSE | 1.0404 m |
| MAE | 0.4178 m |
| R² | 0.7916 |
| Extent F1 | 0.7549 |
| CSI@0.5m | 0.6450 |
| CSI@1.0m | 0.6594 |
| CSI@2.0m | 0.6893 |
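The card does not state exactly how these metrics were computed. The sketch below uses the standard definitions and assumes depth grids in meters. It also assumes that CSI@t treats a cell as flooded when its depth is at least t, so CSI = hits / (hits + misses + false alarms). `regression_metrics` and `csi_at` are hypothetical helpers.

```python
import numpy as np

def regression_metrics(pred, obs):
    """RMSE, MAE and R² between predicted and observed depth grids (meters)."""
    pred = np.asarray(pred, dtype=np.float64).ravel()
    obs = np.asarray(obs, dtype=np.float64).ravel()
    err = pred - obs
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mae = float(np.mean(np.abs(err)))
    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((obs - obs.mean()) ** 2)
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    return {"rmse": rmse, "mae": mae, "r2": r2}

def csi_at(pred, obs, threshold):
    """Critical Success Index for cells at or above a depth threshold."""
    p = np.asarray(pred) >= threshold
    o = np.asarray(obs) >= threshold
    hits = np.sum(p & o)
    misses = np.sum(~p & o)
    false_alarms = np.sum(p & ~o)
    denom = hits + misses + false_alarms
    return float(hits / denom) if denom > 0 else float("nan")
```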
How to Use
1. Install dependencies
```bash
pip install torch segmentation-models-pytorch huggingface_hub rasterio numpy scipy
```
2. Load the model
```python
import torch
import numpy as np
import json
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# Download model weights and normalization stats
weights_path = hf_hub_download("peyterho/flood-depth-unet", "model.pt")
stats_path = hf_hub_download("peyterho/flood-depth-unet", "stats.json")

# Build and load model (trained weights come from model.pt, not ImageNet)
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=9,
    classes=1,
    activation=None,
)
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Load normalization constants (one mean and one std per input channel)
with open(stats_path) as f:
    stats = json.load(f)
means = np.array(stats["means"], dtype=np.float32).reshape(9, 1, 1)
stds = np.array(stats["stds"], dtype=np.float32).reshape(9, 1, 1)
```
3. Prepare your input
The model expects a 9-channel raster of shape (9, 128, 128). Six channels come from your terrain data, and the other three encode the climate scenario you choose:
```python
# === YOUR TERRAIN DATA (channels 0-5) ===
# Each is a 2D numpy array of shape (128, 128) at ~30 m resolution
dem = ...        # elevation in meters (from SRTM, Copernicus DEM, or LiDAR)
slope = ...      # terrain slope in degrees
hand = ...       # Height Above Nearest Drainage in meters
twi = ...        # Topographic Wetness Index
manning = ...    # Manning's roughness (0.03=smooth, 0.06=vegetated, 0.15=dense forest)
soil_perm = ...  # soil permeability in mm/hr (5=clay, 15=loam, 50=sand)

# === YOUR CLIMATE SCENARIO (channels 6-8) ===
# These are scalar values broadcast to the full 128×128 grid
rainfall_intensity = 80.0  # mm/hr: how hard it is raining
storm_duration = 12.0      # hours: how long the storm lasts
return_period = 100.0      # years: rarity of the event (2=common, 500=extreme)

H, W = 128, 128
rain_ch = np.full((H, W), rainfall_intensity, dtype=np.float32)
dur_ch = np.full((H, W), storm_duration, dtype=np.float32)
rp_ch = np.full((H, W), return_period, dtype=np.float32)

# === Stack and normalize ===
raw = np.stack([dem, slope, hand, twi, manning, soil_perm,
                rain_ch, dur_ch, rp_ch])  # shape: (9, 128, 128)
normalized = (raw - means) / (stds + 1e-8)
x = torch.from_numpy(normalized).unsqueeze(0).float()  # (1, 9, 128, 128)
```
4. Predict flood depth
```python
with torch.no_grad():
    depth_map = torch.relu(model(x)).squeeze().numpy()  # (128, 128), values in meters
# depth_map[i, j] = predicted water depth at that pixel
# 0.0 = dry, 1.5 = 1.5 meters of water, etc.
```
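One way to summarize a depth map is to bin it into bands at the 0.5 m, 1.0 m and 2.0 m thresholds that the Performance table uses for CSI. The sketch below does this. The 0.01 m dry cutoff is an assumption borrowed from the flooded-area check in step 5, and the band labels are illustrative.

```python
import numpy as np

# Band edges in meters: 0.01 m dry cutoff, then the CSI thresholds.
BAND_EDGES = [0.01, 0.5, 1.0, 2.0]
BAND_NAMES = ["dry (< 0.01 m)", "0.01-0.5 m", "0.5-1 m", "1-2 m", ">= 2 m"]

def classify_depth(depth_map):
    """Band index per pixel: 0 = dry, ..., 4 = 2 m or deeper."""
    return np.digitize(depth_map, BAND_EDGES)

def band_fractions(depth_map):
    """Fraction of pixels falling in each depth band."""
    classes = classify_depth(depth_map)
    counts = np.bincount(classes.ravel(), minlength=len(BAND_NAMES))
    return dict(zip(BAND_NAMES, counts / classes.size))
```

For example, `band_fractions(depth_map)` reports how much of the tile lies under more than 2 m of water.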
5. Compare climate scenarios
Hold the terrain fixed and vary the storm to see how depth changes:
```python
scenarios = {
    "Moderate storm":     ( 40.0,  6.0,  10),
    "Heavy storm":        ( 80.0, 12.0,  50),
    "Extreme (1-in-500)": (160.0, 24.0, 500),
}

for name, (rain, dur, rp) in scenarios.items():
    # Overwrite the climate channels; terrain channels 0-5 stay fixed
    raw[6, :, :] = rain
    raw[7, :, :] = dur
    raw[8, :, :] = rp
    normed = (raw - means) / (stds + 1e-8)
    x = torch.from_numpy(normed).unsqueeze(0).float()
    with torch.no_grad():
        depth = torch.relu(model(x)).squeeze().numpy()
    print(f"{name}: max depth = {depth.max():.2f}m, "
          f"flooded area = {(depth > 0.01).mean()*100:.1f}%")
```
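A variant of the scenario loop builds every scenario into a single batch. The model then runs once, and `raw` is not modified in place. The sketch below assumes `raw`, `means` and `stds` from steps 2 and 3. `build_scenario_batch` is a hypothetical helper.

```python
import numpy as np

def build_scenario_batch(raw, scenarios, means, stds):
    """Return scenario names and a normalized (N, 9, H, W) float32 batch.

    raw: unnormalized (9, H, W) stack; only channels 6-8 are overwritten.
    scenarios: dict name -> (rainfall mm/hr, duration hours, return period years).
    """
    names = list(scenarios)
    # np.repeat returns a new array, so the caller's raw stays untouched
    batch = np.repeat(raw[np.newaxis].astype(np.float32), len(names), axis=0)
    for i, name in enumerate(names):
        rain, dur, rp = scenarios[name]
        batch[i, 6] = rain
        batch[i, 7] = dur
        batch[i, 8] = rp
    # means/stds have shape (9, 1, 1) and broadcast over the batch axis
    normed = (batch - means) / (stds + 1e-8)
    return names, normed.astype(np.float32)
```

Usage: `names, normed = build_scenario_batch(raw, scenarios, means, stds)`, then `depths = torch.relu(model(torch.from_numpy(normed)))[:, 0].numpy()` inside `torch.no_grad()`. Memory grows linearly with the number of scenarios.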
Where to Get the Terrain Inputs
| Channel | Free source | How to compute |
|---|---|---|
| DEM | Copernicus DEM 30 m or SRTM | Download GeoTIFF for your area |
| Slope | Derived from DEM | np.degrees(np.arctan(np.sqrt(dx**2 + dy**2))) where dy, dx = np.gradient(dem, 30.0) |
| HAND | Derived from DEM | Use pysheds or whitebox (see below) |
| TWI | Derived from DEM | np.log(a / np.tan(slope_radians)), where a is the specific catchment area (upslope area per unit contour width) |
| Manning's n | ESA WorldCover land cover → lookup table | Forest = 0.12, grass = 0.05, urban = 0.02, water = 0.03 |
| Soil perm. | SoilGrids or HWSD | Clay = 5, loam = 15, sand = 50 mm/hr |
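The Manning's n lookup in the table can be applied to a whole land-cover raster at once. The sketch below uses the ESA WorldCover class codes 10 (tree cover), 30 (grassland), 50 (built-up) and 80 (permanent water bodies) with the table's values. It assumes the WorldCover tile, which is distributed at 10 m, has already been resampled onto the DEM grid with nearest-neighbor or mode resampling, since the values are class codes. Every other class falls back to `default`, a placeholder to replace with values suited to your region.

```python
import numpy as np

# ESA WorldCover class code -> Manning's n (values from the table above).
WORLDCOVER_MANNING = {
    10: 0.12,  # Tree cover
    30: 0.05,  # Grassland
    50: 0.02,  # Built-up
    80: 0.03,  # Permanent water bodies
}

def manning_from_worldcover(classes, default=0.05):
    """Map a WorldCover class raster to a float32 Manning's n raster."""
    lut = np.full(256, default, dtype=np.float32)  # one slot per uint8 code
    for code, n in WORLDCOVER_MANNING.items():
        lut[code] = n
    return lut[np.asarray(classes).astype(np.uint8)]
```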
Computing slope from a DEM
```python
dy, dx = np.gradient(dem, 30.0)  # 30 m pixel spacing; returns (d/drow, d/dcol)
slope = np.degrees(np.arctan(np.sqrt(dx**2 + dy**2)))
```
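The TWI row in the terrain-input table can be computed from this slope grid together with a flow-accumulation grid, such as `acc` in the pysheds HAND example. The sketch below assumes `acc` counts upslope cells and that the grid spacing is 30 m. The 0.1° minimum slope, which keeps flat cells from producing infinities, is an arbitrary choice.

```python
import numpy as np

def compute_twi(acc, slope_deg, cell_size=30.0, min_slope_deg=0.1):
    """Topographic Wetness Index ln(a / tan(beta)) per cell."""
    # Specific catchment area: upslope area / contour width = cells * cell_size.
    # Every cell is treated as draining at least its own area.
    a = np.maximum(np.asarray(acc, dtype=np.float64), 1.0) * cell_size
    # Clamp near-flat cells so tan(beta) never reaches zero
    beta = np.radians(np.maximum(slope_deg, min_slope_deg))
    return np.log(a / np.tan(beta)).astype(np.float32)
```

Usage: `twi = compute_twi(np.asarray(acc), slope)`.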
Computing HAND from a DEM with pysheds
```bash
pip install pysheds
```

```python
from pysheds.grid import Grid

grid = Grid.from_raster("dem.tif")
dem = grid.read_raster("dem.tif")

# Condition the DEM so every cell drains somewhere
pit_filled = grid.fill_pits(dem)
flooded = grid.fill_depressions(pit_filled)
inflated = grid.resolve_flats(flooded)

# Flow direction, flow accumulation, then height above the channel network
fdir = grid.flowdir(inflated)
acc = grid.accumulation(fdir)
hand = grid.compute_hand(fdir, dem, acc > 500)  # threshold for channel cells
```
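Derived layers can contain NaN or nodata cells, for example where the DEM itself has no data. A single NaN propagates through the network's convolutions, so each layer should be cleaned before stacking. The sketch below does this. `clean_terrain_layer` is a hypothetical helper, and the fill value is your choice. For HAND, note that a fill of 0 means "at the channel", which biases toward flooding. Passing `min_value=0.0` clamps HAND to be non-negative as a safety check.

```python
import numpy as np

def clean_terrain_layer(layer, nodata=None, fill_value=0.0, min_value=None):
    """Copy a raster-like layer into a finite float32 NumPy array."""
    arr = np.array(layer, dtype=np.float32)  # always a plain ndarray copy
    bad = ~np.isfinite(arr)
    if nodata is not None:
        bad |= arr == np.float32(nodata)
    arr[bad] = fill_value
    if min_value is not None:
        np.maximum(arr, min_value, out=arr)
    return arr
```

Usage: `hand = clean_terrain_layer(hand, fill_value=0.0, min_value=0.0)`.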
Handling Areas Larger Than 128Γ128
For a real study area, your DEM will be much larger than 128×128. Use sliding-window inference with overlapping tiles, averaging the overlaps, to avoid tile-edge artefacts:
```python
import numpy as np
import torch

def predict_large_area(model, raw_input, means, stds,
                       tile_size=128, overlap=32):
    """Predict flood depth over a large raster using tiled inference.

    raw_input: unnormalized (9, H, W) array in the channel order above.
    Predictions in overlapping regions are averaged.
    """
    C, H, W = raw_input.shape
    depth_map = np.zeros((H, W), dtype=np.float32)
    count_map = np.zeros((H, W), dtype=np.float32)
    stride = tile_size - overlap

    for y in range(0, H, stride):
        for x in range(0, W, stride):
            y1, y2 = y, min(y + tile_size, H)
            x1, x2 = x, min(x + tile_size, W)
            patch = raw_input[:, y1:y2, x1:x2]

            # Pad edge tiles to tile_size by repeating border values, so the
            # model never sees artificial zero-elevation / zero-rainfall cells
            ph, pw = patch.shape[1], patch.shape[2]
            if ph < tile_size or pw < tile_size:
                patch = np.pad(patch,
                               ((0, 0), (0, tile_size - ph), (0, tile_size - pw)),
                               mode="edge")

            normed = (patch - means) / (stds + 1e-8)
            t = torch.from_numpy(normed).unsqueeze(0).float()
            with torch.no_grad():
                out = torch.relu(model(t)).squeeze().numpy()

            # Keep only the unpadded region and accumulate
            depth_map[y1:y2, x1:x2] += out[:ph, :pw]
            count_map[y1:y2, x1:x2] += 1.0

    return depth_map / np.maximum(count_map, 1.0)
```
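The loop above can emit small edge tiles that need padding. An alternative sketch shifts the last tile back so it ends exactly at the raster edge, which makes every tile full-size and real. It assumes the raster is at least `tile_size` pixels along the axis; shorter axes still need padding. `tile_starts` is a hypothetical helper.

```python
def tile_starts(length, tile_size=128, overlap=32):
    """Start offsets of full-size tiles covering [0, length) with overlap.

    The final tile is shifted back to end exactly at `length`. If
    length <= tile_size, a single tile at 0 is returned (pad it yourself).
    """
    if length <= tile_size:
        return [0]
    stride = tile_size - overlap
    starts = list(range(0, length - tile_size, stride))
    starts.append(length - tile_size)
    return starts
```

Usage: iterate `for y in tile_starts(H): for x in tile_starts(W):` and slice `raw_input[:, y:y + 128, x:x + 128]`. The overlap accumulation and averaging stay the same.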
Saving Results as a GeoTIFF
```python
import rasterio

# Copy CRS and transform from your input DEM (depth_map must have the DEM's shape)
with rasterio.open("dem.tif") as src:
    meta = src.meta.copy()
meta.update(count=1, dtype="float32", nodata=-9999)

with rasterio.open("flood_depth.tif", "w", **meta) as dst:
    dst.write(depth_map.astype("float32"), 1)
# → open flood_depth.tif in QGIS, ArcGIS, or Google Earth Engine
```
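The metadata above declares `nodata=-9999`, but the predicted map never contains that value. As a result, cells where the DEM had no data are written as ordinary depths. The sketch below marks those cells before writing. It assumes `dem` and `dem_nodata` come from `src.read(1)` and `src.nodata` on the input DEM. `apply_nodata` is a hypothetical helper.

```python
import numpy as np

def apply_nodata(depth_map, dem, dem_nodata, out_nodata=-9999.0):
    """Return a copy of depth_map set to out_nodata where the DEM is invalid."""
    depth = np.asarray(depth_map, dtype=np.float32).copy()
    dem = np.asarray(dem)
    invalid = ~np.isfinite(dem)
    if dem_nodata is not None:
        invalid |= dem == dem_nodata
    depth[invalid] = out_nodata
    return depth
```

Usage: `dst.write(apply_nodata(depth_map, dem, dem_nodata), 1)`.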
Climate Conditioning
Vary the climate channels (6-8) to model different scenarios:
- Rainfall intensity: 10-160 mm/hr (higher = heavier storms)
- Duration: 1-24 hours (longer = more accumulation)
- Return period: 2-500 years (higher = rarer, more severe events)
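These are the documented climate ranges. If the training data spans them, inputs outside them are extrapolation. The small guard below assumes the bounds are inclusive and treats them as the training ranges; both are assumptions. `CLIMATE_RANGES` and `out_of_range` are hypothetical names.

```python
# Documented climate ranges (assumed here to match the training data).
CLIMATE_RANGES = {
    "rainfall_intensity": (10.0, 160.0),  # mm/hr
    "storm_duration": (1.0, 24.0),        # hours
    "return_period": (2.0, 500.0),        # years
}

def out_of_range(rainfall_intensity, storm_duration, return_period):
    """Names of climate inputs that fall outside the documented ranges."""
    values = {
        "rainfall_intensity": rainfall_intensity,
        "storm_duration": storm_duration,
        "return_period": return_period,
    }
    return [name for name, v in values.items()
            if not (CLIMATE_RANGES[name][0] <= v <= CLIMATE_RANGES[name][1])]
```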
Training
- Data: Physics-informed synthetic (DEM + Manning's equation + HAND)
- Loss: Weighted MSE + L1 + Gradient loss
- Optimizer: AdamW (lr=0.001)
- Scheduler: Cosine Annealing
- Epochs: 50
- Parameters: ~24.5M
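The loss line names three terms but not their weights. The NumPy sketch below shows how such a composite loss is typically assembled. The per-pixel weighting `1 + target`, the equal term weights and the finite-difference gradient term are all assumptions, not the author's configuration. A training loop would express the same arithmetic in PyTorch so that it is differentiable.

```python
import numpy as np

def composite_loss(pred, target, w_mse=1.0, w_l1=1.0, w_grad=1.0):
    """Weighted MSE + L1 + gradient loss on (H, W) depth maps (hypothetical weights)."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    err = pred - target
    # Weighted MSE: deeper target pixels count more (target depth is >= 0)
    pixel_w = 1.0 + target
    wmse = np.sum(pixel_w * err ** 2) / np.sum(pixel_w)
    l1 = np.mean(np.abs(err))
    # Gradient loss: L1 between finite-difference spatial derivatives
    grad = (np.mean(np.abs(np.diff(pred, axis=0) - np.diff(target, axis=0)))
            + np.mean(np.abs(np.diff(pred, axis=1) - np.diff(target, axis=1))))
    return float(w_mse * wmse + w_l1 * l1 + w_grad * grad)
```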
Limitations
- Trained on synthetic data; fine-tune on regional hydraulic model outputs (HEC-RAS, LISFLOOD-FP) for production use
- Simplified physics (no infrastructure, backwater effects, or 2D flow routing)
- Best used as a screening tool, pre-training backbone, or educational tool
TODO / Roadmap
Make the model production-ready
- Fine-tune on real hydraulic simulation data: run LISFLOOD-FP or HEC-RAS for a target area and use those depth grids as training labels instead of the synthetic physics approximations
- Add SAR imagery channels (Sentinel-1 VV/VH): lets the model learn from actual observed flood events, not just terrain features
- Swap the backbone for a geospatial foundation model: replace ResNet-34 with Prithvi-EO-2.0 via TerraTorch for much better generalization, especially when labeled data is scarce
Improve what we have
- Train on larger patches (256×256 or 512×512) with a GPU: captures wider drainage patterns and river context
- Generate more diverse synthetic terrain: add coastal, urban, and mountainous terrain types to the data generator
- Build a Gradio demo Space: let users upload a DEM and get a flood depth map in the browser
- Add a temporal dimension: extend the model to predict depth evolution over time (multi-step forecast), not just peak depth
Validate on a real locale
- Download Copernicus DEM for a target area and compute slope, HAND, TWI
- Run inference on real terrain to see how the model performs outside synthetic data
- Compare predictions against known flood extents (e.g., Sen1Floods11 or Copernicus EMS rapid activations)
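The extent comparison, and the Extent F1 in the Performance table, can be sketched as shown below. The sketch assumes observed masks use 1 = water, 0 = dry and -1 = no data, which is the Sen1Floods11 hand-label convention; verify this against the dataset you use. It also assumes a 0.01 m wet threshold, borrowed from the scenario example, and an observed mask already resampled onto the prediction grid. `extent_scores` is a hypothetical helper.

```python
import numpy as np

def extent_scores(depth_map, observed, wet_threshold=0.01, nodata=-1):
    """Precision, recall and F1 of predicted vs observed flood extent."""
    observed = np.asarray(observed)
    valid = observed != nodata  # ignore unlabeled pixels
    pred_wet = (np.asarray(depth_map) > wet_threshold)[valid]
    obs_wet = (observed == 1)[valid]
    tp = np.sum(pred_wet & obs_wet)
    fp = np.sum(pred_wet & ~obs_wet)
    fn = np.sum(~pred_wet & obs_wet)
    precision = tp / (tp + fp) if tp + fp else float("nan")
    recall = tp / (tp + fn) if tp + fn else float("nan")
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else float("nan")
    return {"precision": float(precision), "recall": float(recall), "f1": float(f1)}
```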
Reproduce or extend training
All scripts are in scripts/:
- generate_synthetic_data.py: standalone data generation
- train_full_pipeline.py: end-to-end pipeline (data generation → train → evaluate → push to Hub)
- logs/training.log: full training log from the original run (50 epochs, best RMSE 0.94 m at epoch 33)