CNN (Hybrid): ResNet18 + Transformer for Streamflow Forecasting (Arizona)
A hybrid streamflow model that fuses raw catchment raster tiles (NDVI,
EVI, NDSI, SRTM stacked into a 4-channel image) with meteorological time
series. A ResNet18 encodes the raster into a d_model-dim image token, which
is prepended to the Transformer encoder's input sequence. The Transformer
decoder then forecasts daily streamflow over the next 30 days.
This is the "end-to-end vision" counterpart to kdahal/vitg-streamflow-arizona,
which uses pre-computed ViT embeddings instead of learning visual features
from scratch.
Quick start
```python
import json
import sys

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

# download weights, config, and model source, then make model_code importable
local = snapshot_download(repo_id="kdahal/cnn-streamflow-arizona")
sys.path.insert(0, local)
from model_code import Model

with open(f"{local}/config.json") as f:
    cfg = json.load(f)

model = Model(
    enc_in=cfg["enc_in"], dec_in=cfg["dec_in"], c_out=cfg["c_out"], out_len=cfg["out_len"],
    d_model=cfg["d_model"], n_heads=cfg["n_heads"],
    e_layers=cfg["e_layers"], d_layers=cfg["d_layers"], d_ff=cfg["d_ff"],
    dropout=cfg["dropout"], embed=cfg["embed"], freq=cfg["freq"],
    activation=cfg["activation"], output_attention=cfg["output_attention"],
    mix=cfg["mix"], backbone=cfg["backbone"],
)
model.load_state_dict(load_file(f"{local}/model.safetensors"))
model.eval()
```
Bundled demo:
```bash
pip install -r requirements.txt
python inference.py
```
Inputs
| Name | Shape | Dtype | Meaning |
|---|---|---|---|
| x_enc | (B, 365, 250) | float32 | Past 365 days of normalized meteorological + static features |
| x_mark_enc | (B, 365, 4) | float32 | Time marks (unused; pass zeros) |
| x_dec | (B, 180+30, 1) | float32 | Decoder seed |
| x_mark_dec | (B, 180+30, 4) | float32 | Decoder time marks |
| x_img | (B, 4, 224, 224) | float32 | Catchment-centered tile with channels (NDVI, EVI, NDSI, SRTM) |
Output: (B, 30, 1), the 30-day-ahead daily streamflow forecast (normalized units).
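The table does not spell out how the decoder seed is built. The constructor arguments (enc_in, dec_in, out_len, mix, embed, freq) appear to follow the Informer codebase, where the decoder input is the last label-window values of the target followed by zero placeholders for the forecast horizon. The sketch below assumes that convention (180 known days + 30 zeros) and assumes x_mark_dec is zeros like x_mark_enc; build_inputs is a hypothetical helper, so check inference.py for the exact recipe and forward signature.

```python
import numpy as np

SEQ_LEN, LABEL_LEN, PRED_LEN = 365, 180, 30  # from the Inputs table
N_FEATURES, N_MARKS = 250, 4


def build_inputs(features, streamflow, tile):
    """Hypothetical helper: assemble one sample with a leading batch dimension.

    features:   (365, 250) z-scored meteorology + static features
    streamflow: (365,) z-scored streamflow for the same days (assumed decoder source)
    tile:       (4, 224, 224) raster stack in (NDVI, EVI, NDSI, SRTM) order
    """
    assert features.shape == (SEQ_LEN, N_FEATURES)
    assert streamflow.shape == (SEQ_LEN,)
    assert tile.shape == (4, 224, 224)

    # Assumed Informer-style seed: last LABEL_LEN known values, then zeros for the horizon.
    seed = np.concatenate([streamflow[-LABEL_LEN:], np.zeros(PRED_LEN)])
    return {
        "x_enc": features[None].astype(np.float32),
        "x_mark_enc": np.zeros((1, SEQ_LEN, N_MARKS), dtype=np.float32),  # unused; zeros
        "x_dec": seed[None, :, None].astype(np.float32),
        # assumption: decoder time marks are zeros as well
        "x_mark_dec": np.zeros((1, LABEL_LEN + PRED_LEN, N_MARKS), dtype=np.float32),
        "x_img": tile[None].astype(np.float32),
    }


# usage with random stand-in data
rng = np.random.default_rng(0)
features = rng.standard_normal((SEQ_LEN, N_FEATURES))
streamflow = rng.standard_normal(SEQ_LEN)
tile = rng.random((4, 224, 224))
sample = build_inputs(features, streamflow, tile)
for name, arr in sample.items():
    print(name, arr.shape)
```

Each array can then be wrapped with torch.from_numpy before being passed to the model.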
Architecture
```
x_img ─► ResNet18 (in_chans=4, num_classes=0) ─► Linear(512→d_model) ─► img_token
                                                                            │
                               ┌────────────────────────────────────────────┘
                               ▼
x_enc ─► enc_embedding ─► [img_token | enc_seq] ─► encoder ─► decoder ─► pred
```
- ResNet18: 4-channel input, no ImageNet pretraining (randomly initialized)
- Transformer backbone: 3 encoder / 2 decoder layers, 8 heads, d_model = d_ff = 256, GELU activation
- Fusion: the CNN image embedding is prepended as an extra token at position 0 of the encoder input, so self-attention can freely mix visual and temporal information
- 16.58M parameters, 66 MB on disk
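The fusion step in the diagram can be sketched in plain PyTorch as below. This is an illustrative module, not the code in model_code/: the class name ImageTokenFusion is hypothetical, and the backbone is any module that returns (B, 512) pooled features. With timm, timm.create_model("resnet18", pretrained=False, in_chans=4, num_classes=0) returns such a module; the usage below swaps in a tiny stand-in CNN so the example runs without timm.

```python
import torch
import torch.nn as nn


class ImageTokenFusion(nn.Module):
    """Hypothetical sketch: encode the raster to one token and prepend it to the encoder sequence."""

    def __init__(self, backbone: nn.Module, feat_dim: int = 512, d_model: int = 256):
        super().__init__()
        self.backbone = backbone  # e.g. timm resnet18 with in_chans=4, num_classes=0
        self.proj = nn.Linear(feat_dim, d_model)  # Linear(512 -> d_model) from the diagram

    def forward(self, x_img: torch.Tensor, enc_seq: torch.Tensor) -> torch.Tensor:
        # x_img: (B, 4, H, W); enc_seq: (B, L, d_model), the output of enc_embedding
        feats = self.backbone(x_img)  # (B, feat_dim) pooled image features
        img_token = self.proj(feats).unsqueeze(1)  # (B, 1, d_model)
        return torch.cat([img_token, enc_seq], dim=1)  # (B, L + 1, d_model), image token at position 0


# stand-in backbone producing (B, 512) features, used only so this runs without timm
stand_in = nn.Sequential(
    nn.Conv2d(4, 512, kernel_size=4, stride=4),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
fusion = ImageTokenFusion(stand_in)
x_img = torch.randn(2, 4, 224, 224)
enc_seq = torch.randn(2, 365, 256)
out = fusion(x_img, enc_seq)
print(out.shape)  # torch.Size([2, 366, 256])
```

Because the image contributes a single token, the encoder sequence grows only from 365 to 366 positions while every temporal position can still attend to it.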
Training
- Loss: Masked Huber on the 30-day window
- Optimizer: AdamW; LR, schedule, and augmentations as in the TADA paper
- 5 seeds; this checkpoint is seed 1
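The card names the loss but not its details. A minimal sketch, assuming that "masked" means target days with missing streamflow (NaN) are excluded from the average and that the Huber threshold is delta=1.0; neither detail is stated above, and masked_huber is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def masked_huber(pred: torch.Tensor, target: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Huber loss averaged over non-missing targets only (assumption: missing = NaN)."""
    mask = ~torch.isnan(target)
    # replace NaNs so they cannot propagate into the loss or its gradient
    target_filled = torch.where(mask, target, torch.zeros_like(target))
    loss = F.huber_loss(pred, target_filled, reduction="none", delta=delta)
    mask_f = mask.to(loss.dtype)
    # clamp avoids division by zero when every target in the batch is missing
    return (loss * mask_f).sum() / mask_f.sum().clamp(min=1.0)


# usage on a (B, 30, 1)-style tensor, shortened to 3 days
pred = torch.zeros(1, 3, 1)
target = torch.tensor([[[0.5], [float("nan")], [3.0]]])
loss = masked_huber(pred, target)
print(loss.item())  # 1.3125
```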
Normalization
Inputs to this model must be z-score normalized with the frozen training
statistics, not with statistics computed from your own data. These statistics
ship as locked_stats.json, identical to the file used for the TNFR baseline
(the CNN consumes the same 250-dim meteorology + static-attribute tensor as
TNFR, plus the raw raster x_img).
```python
import json

import numpy as np

with open("locked_stats.json") as f:
    stats = json.load(f)

site = "09379025"  # basin ID; must be a key in basin_dynamic_stats
means = stats["basin_dynamic_stats"][site]["means"]
stds = stats["basin_dynamic_stats"][site]["stds"]

# apply to a daily DataFrame `df` with columns matching the feature names
for col in means:
    df[col] = (df[col] - means[col]) / (stds[col] + 1e-6)

# un-normalize the model output `y_pred` (a torch tensor, possibly on GPU) and clip negative flows to 0
mean_q, std_q = means["streamflow"], stds["streamflow"]
y_real = np.maximum(y_pred.detach().cpu().numpy() * std_q + mean_q, 0.0)
```
Top-level keys in locked_stats.json:
- basin_dynamic_stats: per-basin means/stds for 40 dynamic features
- global_static_stats: static catchment-attribute means/stds
- static_data_lookup: per-basin raw static attribute values
Note on x_img: the raster stack (NDVI, EVI, NDSI, SRTM) is not
normalized with locked_stats.json; each channel was pre-scaled in the
tile-preparation pipeline. Pass rasters scaled with the same per-channel
convention used to prepare the training tiles; the ResNet18 was trained to be
robust to the float ranges that convention produces.
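Channel order matters for x_img. The sketch below only stacks layers that your tile-preparation pipeline has already scaled into the (B, 4, 224, 224) layout from the Inputs table; it deliberately applies no rescaling, since the per-channel convention is not documented here. stack_tile and its dict-of-layers input are hypothetical.

```python
import numpy as np

CHANNEL_ORDER = ("NDVI", "EVI", "NDSI", "SRTM")  # order listed in the Inputs table


def stack_tile(layers: dict) -> np.ndarray:
    """Hypothetical helper: stack pre-scaled 224x224 layers into a (1, 4, 224, 224) float32 batch."""
    missing = [name for name in CHANNEL_ORDER if name not in layers]
    if missing:
        raise KeyError(f"missing channels: {missing}")
    # no rescaling here: values are assumed to already follow the training-tile convention
    tile = np.stack([np.asarray(layers[name], dtype=np.float32) for name in CHANNEL_ORDER], axis=0)
    if tile.shape != (4, 224, 224):
        raise ValueError(f"expected (4, 224, 224), got {tile.shape}")
    return tile[None]  # add batch dimension


# usage: constant layers make the channel order easy to check
layers = {name: np.full((224, 224), i, dtype=np.float64) for i, name in enumerate(CHANNEL_ORDER)}
out = stack_tile(layers)
print(out.shape, out.dtype)
```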
Files
- model.safetensors: weights (state_dict)
- config.json: constructor kwargs (includes img_channels, img_size, backbone)
- locked_stats.json: frozen per-basin normalization statistics (315 Arizona basins, 40 dynamic features)
- model_code/: self-contained PyTorch source (timm is required for ResNet18)
- inference.py: end-to-end demo
License
MIT.
Related repos
- kdahal/tnfr-streamflow-arizona: time-series-only baseline
- kdahal/vitg-streamflow-arizona: ViT-gated variant using pre-computed satellite embeddings