CNN (Hybrid): ResNet18 + Transformer for Streamflow Forecasting (Arizona)

A hybrid streamflow model that fuses raw catchment raster tiles (NDVI, EVI, NDSI, SRTM stacked into a 4-channel image) with meteorological time series. A ResNet18 encodes the raster into a d_model-dim image token, which is prepended to the Transformer encoder's input sequence. The Transformer decoder then forecasts daily streamflow over the next 30 days.

This is the "end-to-end vision" counterpart to kdahal/vitg-streamflow-arizona, which uses pre-computed ViT embeddings instead of learning visual features from scratch.

Quick start

from huggingface_hub import snapshot_download
import sys, json, torch
from safetensors.torch import load_file

local = snapshot_download(repo_id="kdahal/cnn-streamflow-arizona")
sys.path.insert(0, local)

from model_code import Model

with open(f"{local}/config.json") as f:
    cfg = json.load(f)
model = Model(
    enc_in=cfg["enc_in"], dec_in=cfg["dec_in"], c_out=cfg["c_out"], out_len=cfg["out_len"],
    d_model=cfg["d_model"], n_heads=cfg["n_heads"],
    e_layers=cfg["e_layers"], d_layers=cfg["d_layers"], d_ff=cfg["d_ff"],
    dropout=cfg["dropout"], embed=cfg["embed"], freq=cfg["freq"],
    activation=cfg["activation"], output_attention=cfg["output_attention"],
    mix=cfg["mix"], backbone=cfg["backbone"],
)
model.load_state_dict(load_file(f"{local}/model.safetensors"))
model.eval()

Bundled demo:

pip install -r requirements.txt
python inference.py

Inputs

Name        Shape              Dtype    Meaning
x_enc       (B, 365, 250)      float32  Past 365 days of normalized meteorological + static features
x_mark_enc  (B, 365, 4)        float32  Time marks (unused; pass zeros)
x_dec       (B, 180+30, 1)     float32  Decoder seed
x_mark_dec  (B, 180+30, 4)     float32  Decoder time marks
x_img       (B, 4, 224, 224)   float32  Catchment-centered raster tile; channels (NDVI, EVI, NDSI, SRTM)

Output: (B, 30, 1), the 30-day-ahead streamflow forecast in normalized units.
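To sanity-check shapes before wiring in real data, the sketch below builds a dummy batch that matches the Inputs table. It assumes an Informer-style decoder seed (the first 180 steps carry known normalized streamflow, the last 30 are zero placeholders); that split is only inferred from the "180+30" shape, so confirm the exact recipe in inference.py. The helper name make_dummy_inputs is hypothetical, and so is the positional argument order of the forward call.

```python
import torch

SEQ_LEN, N_FEAT, N_MARK = 365, 250, 4   # from the Inputs table
LABEL_LEN, PRED_LEN = 180, 30           # the "180+30" decoder length


def make_dummy_inputs(past_q):
    """Hypothetical helper: build one batch of inputs with the documented shapes.

    past_q: (B, LABEL_LEN, 1) normalized streamflow. ASSUMPTION: Informer-style
    decoder seed = known history followed by PRED_LEN zero placeholders.
    """
    b = past_q.shape[0]
    x_enc = torch.zeros(b, SEQ_LEN, N_FEAT)                            # replace with real normalized features
    x_mark_enc = torch.zeros(b, SEQ_LEN, N_MARK)                       # unused by the model; zeros per the card
    x_dec = torch.cat([past_q, torch.zeros(b, PRED_LEN, 1)], dim=1)    # (B, 210, 1)
    x_mark_dec = torch.zeros(b, LABEL_LEN + PRED_LEN, N_MARK)          # zeros as placeholders
    x_img = torch.zeros(b, 4, 224, 224)                                # NDVI, EVI, NDSI, SRTM
    return x_enc, x_mark_enc, x_dec, x_mark_dec, x_img


inputs = make_dummy_inputs(torch.randn(2, LABEL_LEN, 1))
# With `model` from the Quick start; this positional order is an ASSUMPTION:
# with torch.no_grad():
#     pred = model(*inputs)   # documented output shape: (B, 30, 1)
```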

Architecture

x_img  ──► ResNet18 (in_chans=4, num_classes=0) ─► Linear(512→d_model) ─► img_token ┐
                                 ┌──────────────────────────────────────────────────┘
                                 ▼
x_enc  ──► enc_embedding ─► [img_token | enc_seq] ──► encoder ─► decoder ─► pred
  • ResNet18: 4-channel input, no ImageNet pretraining (randomly initialized)
  • Transformer backbone: 3 enc / 2 dec layers, 8 heads, d_model=d_ff=256, GELU
  • Fusion: the CNN image embedding is prepended as an extra token at position 0 of the encoder input, so attention can freely mix visual + temporal information
  • 16.58M parameters, 66 MB on disk
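The fusion step can be illustrated with a self-contained PyTorch sketch. This is not the model_code implementation: a single strided convolution stands in for the timm ResNet18 (which, with num_classes=0, returns a 512-dim pooled feature), the encoder embedding is a plain Linear without positional or time encodings, and torch.nn.TransformerEncoder replaces the repo's encoder. The class name ImageTokenFusion is hypothetical. What it does show is the shape bookkeeping: one projected image token concatenated at position 0, giving an encoder sequence one step longer than x_enc.

```python
import torch
import torch.nn as nn


class ImageTokenFusion(nn.Module):
    """Hypothetical sketch: prepend one CNN image token to the embedded time series."""

    def __init__(self, n_feat=250, d_model=256, n_heads=8, e_layers=3, img_feat=512):
        super().__init__()
        # Stand-in backbone (NOT ResNet18): (B, 4, H, W) -> (B, img_feat)
        self.cnn = nn.Sequential(
            nn.Conv2d(4, img_feat, kernel_size=7, stride=4, padding=3),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.img_proj = nn.Linear(img_feat, d_model)      # Linear(512 -> d_model)
        self.enc_embedding = nn.Linear(n_feat, d_model)   # simplified: no positional/time encoding
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=d_model, activation="gelu", batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=e_layers)

    def forward(self, x_enc, x_img):
        img_token = self.img_proj(self.cnn(x_img)).unsqueeze(1)  # (B, 1, d_model)
        enc_seq = self.enc_embedding(x_enc)                        # (B, L, d_model)
        tokens = torch.cat([img_token, enc_seq], dim=1)            # image token at position 0
        return self.encoder(tokens)                                # (B, L + 1, d_model)


fusion = ImageTokenFusion().eval()
with torch.no_grad():
    out = fusion(torch.randn(2, 365, 250), torch.randn(2, 4, 224, 224))
print(out.shape)  # torch.Size([2, 366, 256])
```

Because the image token sits in the same sequence as the daily steps, every self-attention layer can attend between any day and the catchment embedding, which is what lets attention mix visual and temporal information.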

Training

  • Loss: Masked Huber on the 30-day window
  • Optimizer: AdamW; LR, schedule, and augmentations as in the TADA paper
  • 5 seeds; this checkpoint is seed 1
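A masked Huber loss averages the per-element Huber penalty over observed targets only, so gaps in the gauge record do not contribute to the gradient. The sketch below assumes missing observations are encoded as NaN in the target and uses delta=1.0 as a placeholder; the card states neither detail, so treat both as assumptions rather than the training configuration.

```python
import torch
import torch.nn.functional as F


def masked_huber(pred, target, delta=1.0):
    """Huber loss averaged over observed (non-NaN) target entries only.

    ASSUMPTIONS: missing observations are NaN in `target`; delta=1.0 is a placeholder.
    """
    mask = ~torch.isnan(target)
    # Fill gaps with the prediction itself so they yield zero loss and no NaNs
    safe_target = torch.where(mask, target, pred.detach())
    per_elem = F.huber_loss(pred, safe_target, reduction="none", delta=delta) * mask
    # Mean over observed entries; clamp avoids 0/0 when the whole window is missing
    return per_elem.sum() / mask.sum().clamp(min=1)
```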

Normalization

Inputs to this model must be z-score normalized with the frozen training statistics, not with statistics computed from your own data. They ship as locked_stats.json, which is identical to the file used for the TNFR baseline: the CNN consumes the same 250-dim meteorology + static-attribute tensor as TNFR, plus the raw raster x_img.

import json
import numpy as np

with open("locked_stats.json") as f:
    stats = json.load(f)
site  = "09379025"   # basin IDs are JSON keys, so keep them as strings
means = stats["basin_dynamic_stats"][site]["means"]
stds  = stats["basin_dynamic_stats"][site]["stds"]

# apply to a daily pandas DataFrame `df` whose columns match the feature names
for col in means:
    df[col] = (df[col] - means[col]) / (stds[col] + 1e-6)

# un-normalize the model output `y_pred` (torch tensor, shape (B, 30, 1));
# reuse the same epsilon for an exact inverse and clip negative flows at 0
mean_q, std_q = means["streamflow"], stds["streamflow"]
y_real = np.maximum(y_pred.detach().cpu().numpy() * (std_q + 1e-6) + mean_q, 0.0)

Top-level keys in locked_stats.json:

  • basin_dynamic_stats: per-basin means/stds for 40 dynamic features
  • global_static_stats: static catchment-attribute means/stds
  • static_data_lookup: per-basin raw static attribute values
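Basin IDs are keys in basin_dynamic_stats, and JSON object keys are always strings, so IDs must be looked up as strings: converting "09379025" to an integer drops the leading zero and the lookup fails. The helper below (site_stats, a hypothetical name) fails loudly on unknown basins and on mismatched feature sets. It assumes only the means/stds nesting used in the normalization example above; the demo dictionary is a made-up stand-in, not real statistics.

```python
import json


def site_stats(stats, site):
    """Hypothetical helper: return (means, stds) for one basin, failing loudly."""
    table = stats["basin_dynamic_stats"]
    if site not in table:
        raise KeyError(f"basin {site!r} not found ({len(table)} basins available)")
    means, stds = table[site]["means"], table[site]["stds"]
    mismatched = set(means) ^ set(stds)   # features present in only one of the two dicts
    if mismatched:
        raise ValueError(f"means/stds disagree on features: {sorted(mismatched)}")
    return means, stds


# Made-up stand-in with the same nesting as locked_stats.json (NOT real statistics)
demo = json.loads(
    '{"basin_dynamic_stats": {"09379025": '
    '{"means": {"streamflow": 2.0}, "stds": {"streamflow": 4.0}}}}'
)
means, stds = site_stats(demo, "09379025")   # works: the ID is a string
# site_stats(demo, str(9379025)) raises KeyError: "9379025" lost its leading zero
```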

Note on x_img: the raster stack (NDVI, EVI, NDSI, SRTM) is not normalized with locked_stats.json; each channel was pre-scaled in the tile-preparation pipeline. Pass rasters in the same convention your tile-preparation pipeline produces; the ResNet18 was trained to be robust to the float ranges used.
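One way to assemble x_img from four single-band rasters is sketched below. It assumes each band array is already centered on the catchment, already scaled by your tile pipeline (no normalization happens here), and at least 224×224 pixels; it center-crops each band to 224×224 and stacks them in the channel order from the Inputs table. The function names are hypothetical, and resampling or reprojection is out of scope.

```python
import numpy as np

BANDS = ("NDVI", "EVI", "NDSI", "SRTM")  # channel order from the Inputs table
TILE = 224


def center_crop(band, size=TILE):
    """Crop a 2-D array to size x size around its center."""
    h, w = band.shape
    if h < size or w < size:
        raise ValueError(f"band of shape {band.shape} is smaller than {size}x{size}")
    top, left = (h - size) // 2, (w - size) // 2
    return band[top:top + size, left:left + size]


def stack_tile(bands):
    """Hypothetical helper: dict of pre-scaled 2-D bands -> (1, 4, 224, 224) float32."""
    tile = np.stack([center_crop(np.asarray(bands[name])) for name in BANDS], axis=0)
    return tile.astype(np.float32)[np.newaxis]   # add batch dimension
```

The result can be handed to the model with torch.from_numpy(stack_tile(bands)).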

Files

  • model.safetensors: weights (state_dict)
  • config.json: constructor kwargs (includes img_channels, img_size, backbone)
  • locked_stats.json: frozen per-basin normalization statistics (315 Arizona basins, 40 dynamic features)
  • model_code/: self-contained PyTorch source (timm is also required for the ResNet18)
  • inference.py: end-to-end demo

License

MIT.

Related repos

  • kdahal/tnfr-streamflow-arizona: time-series-only baseline
  • kdahal/vitg-streamflow-arizona: ViT-gated variant using pre-computed satellite embeddings