ViT-G: ViT-Gated Transformer for Streamflow Forecasting (Arizona)
A Transformer streamflow model that fuses meteorological time series with ViT-DINOv2 satellite embeddings through a learnable per-timestep gate. The gate conditions on weather, compressed ViT features, and a staleness signal (how many days since the ViT token last changed), so the network learns when to trust the satellite signal instead of treating it as a constant nuisance.
It warm-starts from the TNFR baseline (kdahal/tnfr-streamflow-arizona),
freezes the 40-year-trained Transformer backbone, and trains only the
compressor + gate. This was the strongest single model on natural,
monsoon-driven basins in the TADA study.
Quick start
from huggingface_hub import snapshot_download
import sys, json, torch
from safetensors.torch import load_file
local = snapshot_download(repo_id="kdahal/vitg-streamflow-arizona")
sys.path.insert(0, local)
from model_code import Model
cfg = json.load(open(f"{local}/config.json"))
model = Model(
    enc_in=cfg["enc_in"], dec_in=cfg["dec_in"], c_out=cfg["c_out"], out_len=cfg["out_len"],
    d_model=cfg["d_model"], n_heads=cfg["n_heads"],
    e_layers=cfg["e_layers"], d_layers=cfg["d_layers"], d_ff=cfg["d_ff"],
    dropout=cfg["dropout"], embed=cfg["embed"], freq=cfg["freq"],
    activation=cfg["activation"], output_attention=cfg["output_attention"], mix=cfg["mix"],
)
model.load_state_dict(load_file(f"{local}/model.safetensors"))
model.eval()
Bundled demo:
pip install -r requirements.txt
python inference.py
Inputs
| Name | Shape | Dtype | Meaning |
|---|---|---|---|
| x_enc | (B, 365, 634) | float32 | Weather features (250) followed by ViT features (384), concatenated on the last axis |
| x_mark_enc | (B, 365, 4) | float32 | Time marks (unused; pass zeros for daily data) |
| x_dec | (B, 180+30, 1) | float32 | Decoder seed: 180 days history + 30 zeros |
| x_mark_dec | (B, 180+30, 4) | float32 | Decoder time marks |
The ViT segment should be DINOv2 ViT-small-14 CLS embeddings (dim 384) from your catchment-centered Sentinel-2 tile for each day. Repeat the last-available embedding when no new tile is produced; the model's staleness logic handles the carry.
Output: because output_attention=True, the model returns a tuple (pred, attns),
where pred has shape (B, 30, 1). Use out[0] for the forecast.
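The shapes in the input table can be assembled as in the following minimal sketch. It assumes a normalized weather array, a per-day ViT array whose rows are all-NaN on days without a new tile, and 180 days of normalized streamflow history. The helper names `forward_fill_rows` and `build_inputs`, the NaN convention, and the use of zeros for the decoder time marks are illustrative assumptions, not part of the shipped code.

```python
import numpy as np

SEQ_LEN, LABEL_LEN, PRED_LEN = 365, 180, 30   # lengths taken from the input table
WEATHER_DIM, VIT_DIM = 250, 384

def forward_fill_rows(emb):
    """Repeat the last available row wherever a row is all-NaN (hypothetical helper)."""
    has_tile = ~np.isnan(emb).all(axis=1)
    # For every day, index of the most recent day that has a tile (-1 if none yet).
    idx = np.where(has_tile, np.arange(len(emb)), -1)
    idx = np.maximum.accumulate(idx)
    if (idx < 0).any():
        raise ValueError("no embedding available before the first tile")
    return emb[idx]

def build_inputs(weather, vit, flow_history):
    """weather: (365, 250); vit: (365, 384) with NaN rows; flow_history: (180,)."""
    vit = forward_fill_rows(vit)
    x_enc = np.concatenate([weather, vit], axis=1)[None].astype(np.float32)  # (1, 365, 634)
    x_mark_enc = np.zeros((1, SEQ_LEN, 4), dtype=np.float32)                 # unused marks
    seed = np.concatenate([flow_history, np.zeros(PRED_LEN)])                # 180 history + 30 zeros
    x_dec = seed[None, :, None].astype(np.float32)                           # (1, 210, 1)
    # Assumption: zeros are acceptable for the decoder marks as well.
    x_mark_dec = np.zeros((1, LABEL_LEN + PRED_LEN, 4), dtype=np.float32)
    return x_enc, x_mark_enc, x_dec, x_mark_dec

# Synthetic data standing in for one basin: a new tile every 5 days.
rng = np.random.default_rng(0)
weather = rng.normal(size=(SEQ_LEN, WEATHER_DIM))
vit = np.full((SEQ_LEN, VIT_DIM), np.nan)
vit[::5] = rng.normal(size=(len(vit[::5]), VIT_DIM))
x_enc, x_mark_enc, x_dec, x_mark_dec = build_inputs(weather, vit, rng.normal(size=LABEL_LEN))
print(x_enc.shape, x_dec.shape)
```

The arrays can then be wrapped with `torch.from_numpy` before calling the model. The argument order is assumed to follow the table; check inference.py for the exact call.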
Architecture
x_enc ─ split ─┬─ weather (250) ──────────────────────────────┐
               │                                              ├─ concat ─► Transformer (TADA) ─► pred
               └─ ViT (384) ─► compressor (384→128→32) ─► * ──┘
                                                          ▲
               gate(weather | compressed | staleness) ────┘
- Transformer backbone: 3 enc / 2 dec layers, 8 heads, d_model=d_ff=256, dropout=0.2
- Compressor: 2-layer MLP with GELU (384 → 128 → 32)
- Gate: 2-layer MLP with sigmoid output (283 → 64 → 32); conditions on weather + compressed ViT + staleness
- Staleness signal: per-timestep count of consecutive unchanged ViT tokens, divided by 30
- 5.36M parameters total, 21 MB on disk (only gate + compressor are trainable once the baseline is loaded)
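The components listed above could be wired together roughly as follows. This is a sketch under stated assumptions, not the code in model_code/: the class name `GatedViTFusion`, the GELU in the gate's hidden layer, detecting "unchanged" tokens by exact equality, and the 282-wide output (250 weather + 32 gated ViT features) passed to the backbone are all guesses consistent with the listed dimensions.

```python
import torch
import torch.nn as nn

WEATHER_DIM, VIT_DIM, COMPRESSED_DIM = 250, 384, 32

def staleness(vit, scale=30.0):
    """Count of consecutive unchanged ViT tokens per timestep, divided by `scale`.

    vit: (B, T, D). Returns (B, T, 1).
    """
    changed = (vit[:, 1:] != vit[:, :-1]).any(dim=-1)          # (B, T-1)
    B, T = vit.shape[0], vit.shape[1]
    count = torch.zeros(B, T, dtype=vit.dtype, device=vit.device)
    for t in range(1, T):
        # Reset to 0 when the token changes, otherwise keep counting.
        count[:, t] = torch.where(
            changed[:, t - 1], torch.zeros_like(count[:, t]), count[:, t - 1] + 1
        )
    return (count / scale).unsqueeze(-1)

class GatedViTFusion(nn.Module):
    """Hypothetical compressor + gate in front of the frozen backbone."""

    def __init__(self):
        super().__init__()
        self.compressor = nn.Sequential(
            nn.Linear(VIT_DIM, 128), nn.GELU(), nn.Linear(128, COMPRESSED_DIM)
        )
        gate_in = WEATHER_DIM + COMPRESSED_DIM + 1             # 250 + 32 + 1 = 283
        self.gate = nn.Sequential(
            nn.Linear(gate_in, 64), nn.GELU(),                 # hidden activation is an assumption
            nn.Linear(64, COMPRESSED_DIM), nn.Sigmoid(),
        )

    def forward(self, x_enc):
        weather, vit = x_enc[..., :WEATHER_DIM], x_enc[..., WEATHER_DIM:]
        z = self.compressor(vit)                               # (B, T, 32)
        g = self.gate(torch.cat([weather, z, staleness(vit)], dim=-1))
        return torch.cat([weather, g * z], dim=-1)             # (B, T, 282)

x = torch.randn(2, 365, WEATHER_DIM + VIT_DIM)
# Hold the first ViT token for days 0..9 to mimic a stale tile.
x[:, 1:10, WEATHER_DIM:] = x[:, :1, WEATHER_DIM:].clone()
s = staleness(x[..., WEATHER_DIM:])
fused = GatedViTFusion()(x)
print(fused.shape, s[0, :11, 0])
```

Because the gate output lies in (0, 1) per compressed channel, a stale token can be attenuated channel by channel rather than switched off wholesale.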
Training
- Warm-start the inner Transformer from the TNFR 40-year baseline (seed 4).
- Freeze the Transformer; train only the compressor and gate.
- Loss: masked Huber (δ=1.0); AdamW, LR=1e-3, dropout 0.2, 4 epochs, OneCycle schedule.
- 5 seeds; this checkpoint is seed 1.
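The recipe above can be illustrated with a toy loop. This is a sketch, not the training script: the two `nn.Linear` layers stand in for the trainable compressor + gate and the frozen Transformer, the keys `fusion` and `backbone` are hypothetical, the mask is assumed to mark days with an observed target, and LR=1e-3 is assumed to be OneCycle's peak rate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def masked_huber(pred, target, mask, delta=1.0):
    """Huber loss averaged over positions where mask is true."""
    per_elem = F.huber_loss(pred, target, reduction="none", delta=delta)
    mask = mask.to(per_elem.dtype)
    return (per_elem * mask).sum() / mask.sum().clamp(min=1.0)

torch.manual_seed(0)
# Toy stand-ins: `fusion` is trainable, `backbone` is frozen (names are hypothetical).
model = nn.ModuleDict({"fusion": nn.Linear(8, 4), "backbone": nn.Linear(4, 1)})
for p in model["backbone"].parameters():
    p.requires_grad = False

w_backbone = model["backbone"].weight.clone()   # kept to check the freeze afterwards
w_fusion = model["fusion"].weight.clone()

trainable = [p for p in model.parameters() if p.requires_grad]
epochs, steps_per_epoch = 4, 10
opt = torch.optim.AdamW(trainable, lr=1e-3)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=1e-3, total_steps=epochs * steps_per_epoch
)

for _ in range(epochs * steps_per_epoch):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    missing = torch.rand(16, 1) < 0.2            # pretend 20% of targets are unobserved
    y = torch.where(missing, torch.zeros_like(y), y)   # fill so NaNs never reach the loss
    loss = masked_huber(model["backbone"](model["fusion"](x)), y, ~missing)
    opt.zero_grad()
    loss.backward()                              # gradients still flow through the frozen layer
    opt.step()
    sched.step()
```

Passing only the parameters with `requires_grad=True` to the optimizer keeps optimizer state small and guarantees the warm-started backbone is left untouched.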
Performance
Strongest single model on natural / monsoon-dominated Arizona basins in the TADA study; see the paper for median NSE/KGE tables by regulation tier and season.
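For readers who want to score their own forecasts with the two metrics named above, here is a minimal sketch using the standard definitions of Nash-Sutcliffe efficiency and the 2009 form of Kling-Gupta efficiency. Which KGE variant the paper reports is not stated here, so treat the choice as an assumption. Both functions expect un-normalized flow.

```python
import numpy as np

def nse(obs, sim):
    """Nash-Sutcliffe efficiency: 1 is perfect, 0 matches the observed mean."""
    obs, sim = np.asarray(obs, dtype=float), np.asarray(sim, dtype=float)
    return 1.0 - np.sum((obs - sim) ** 2) / np.sum((obs - obs.mean()) ** 2)

def kge(obs, sim):
    """Kling-Gupta efficiency (2009 form): correlation, variability and bias terms."""
    obs, sim = np.asarray(obs, dtype=float), np.asarray(sim, dtype=float)
    r = np.corrcoef(obs, sim)[0, 1]        # linear correlation
    alpha = sim.std() / obs.std()          # variability ratio
    beta = sim.mean() / obs.mean()         # bias ratio
    return 1.0 - np.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)

obs = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(nse(obs, obs), kge(obs, obs))
```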
Normalization
Inputs to this model must be z-score normalized with the frozen training
statistics, not statistics computed from your own data. The stats for
ViT-G include the 384 ViT-DINOv2 feature dimensions (so this file is ~13 MB,
larger than the base variant). They ship as locked_stats.json.
import json
import numpy as np

with open("locked_stats.json") as f:
    stats = json.load(f)
site = "09379025"
means = stats["basin_dynamic_stats"][site]["means"]  # includes vit_0 ... vit_383
stds = stats["basin_dynamic_stats"][site]["stds"]
# apply to a daily dataframe `df` (assumed to exist) with columns matching feature names
for col in means:
    df[col] = (df[col] - means[col]) / (stds[col] + 1e-6)
# un-normalize streamflow output; `y_pred` is the forecast tensor, e.g. out[0]
mean_q, std_q = means["streamflow"], stds["streamflow"]
y_real = np.maximum(y_pred.detach().cpu().numpy() * std_q + mean_q, 0.0)
Top-level keys in locked_stats.json:
- basin_dynamic_stats: per-basin means/stds for 424 features (40 met/hydro + 384 ViT dims)
- global_static_stats: static catchment-attribute means/stds
- static_data_lookup: per-basin raw static attribute values
The ViT dimensions in basin_dynamic_stats correspond to DINOv2 ViT-small-14
CLS embeddings of catchment-centered Sentinel-2 tiles.
Files
- model.safetensors: weights (state_dict)
- config.json: constructor kwargs (including weather_dim, vit_dim, compressed_dim)
- locked_stats.json: frozen per-basin normalization statistics (315 Arizona basins, 424 dynamic features)
- model_code/: self-contained PyTorch source
- inference.py: end-to-end demo
License
MIT.
Related repos
- kdahal/tnfr-streamflow-arizona: time-series-only baseline
- kdahal/cnn-streamflow-arizona: CNN hybrid (raster tiles instead of ViT embeddings)