ViT-G: ViT-Gated Transformer for Streamflow Forecasting (Arizona)

A Transformer streamflow model that fuses meteorological time series with ViT-DINOv2 satellite embeddings through a learnable per-timestep gate. The gate conditions on weather, compressed ViT features, and a staleness signal (how many days since the ViT token last changed), so the network learns when to trust the satellite signal instead of treating it as a constant nuisance.

It warm-starts from the TNFR baseline (kdahal/tnfr-streamflow-arizona), freezes the 40-year-trained Transformer backbone, and trains only the compressor + gate. This was the strongest single model on natural, monsoon-driven basins in the TADA study.

Quick start

from huggingface_hub import snapshot_download
import sys, json, torch
from safetensors.torch import load_file

local = snapshot_download(repo_id="kdahal/vitg-streamflow-arizona")
sys.path.insert(0, local)

from model_code import Model

cfg = json.load(open(f"{local}/config.json"))
model = Model(
    enc_in=cfg["enc_in"], dec_in=cfg["dec_in"], c_out=cfg["c_out"], out_len=cfg["out_len"],
    d_model=cfg["d_model"], n_heads=cfg["n_heads"],
    e_layers=cfg["e_layers"], d_layers=cfg["d_layers"], d_ff=cfg["d_ff"],
    dropout=cfg["dropout"], embed=cfg["embed"], freq=cfg["freq"],
    activation=cfg["activation"], output_attention=cfg["output_attention"], mix=cfg["mix"],
)
model.load_state_dict(load_file(f"{local}/model.safetensors"))
model.eval()

Bundled demo:

pip install -r requirements.txt
python inference.py

Inputs

Name        Shape           Dtype    Meaning
x_enc       (B, 365, 634)   float32  [weather(250) | ViT(384)], concatenated on the last axis
x_mark_enc  (B, 365, 4)     float32  Time marks (unused; pass zeros for daily data)
x_dec       (B, 180+30, 1)  float32  Decoder seed: 180 days history + 30 zeros
x_mark_dec  (B, 180+30, 4)  float32  Decoder time marks

The ViT segment should be DINOv2 ViT-small-14 CLS embeddings (dim 384) from your catchment-centered Sentinel-2 tile for each day; repeat the last-available embedding when no new tile is produced. The model's staleness logic handles the carry.
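The carry-forward rule can be sketched as follows. This is a minimal illustration, not the repository's preprocessing code: the function name `forward_fill_embeddings` is hypothetical, and filling days before the first available tile with a zero vector is an assumption the source does not specify.

```python
from typing import List, Optional, Sequence


def forward_fill_embeddings(
    daily: Sequence[Optional[Sequence[float]]], dim: int = 384
) -> List[List[float]]:
    """Repeat the last available embedding on days without a new tile."""
    filled: List[List[float]] = []
    last: Optional[List[float]] = None
    for emb in daily:
        if emb is not None:
            if len(emb) != dim:
                raise ValueError(f"expected {dim} values, got {len(emb)}")
            last = list(emb)
        # Assumption: days before the first tile get a zero vector.
        filled.append(list(last) if last is not None else [0.0] * dim)
    return filled
```

Pass one entry per day, with `None` on days that have no new tile; the result has the same length and no gaps.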

Output: because output_attention=True, the model returns a tuple (pred, attns), where pred has shape (B, 30, 1). Use out[0] for the forecast.
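As a shape check, the four inputs listed above can be assembled like this. It is a minimal NumPy sketch with random placeholder data. The helper `build_inputs` and the constant names are hypothetical, and using zeros for the decoder time marks is an assumption carried over from the encoder marks. Convert the arrays with `torch.from_numpy` before passing them to the model.

```python
import numpy as np

SEQ_LEN, LABEL_LEN, PRED_LEN = 365, 180, 30   # window lengths from the input table
WEATHER_DIM, VIT_DIM = 250, 384               # 250 + 384 = 634 encoder features


def build_inputs(weather, vit, q_hist):
    """Assemble model inputs from already-normalized arrays.

    weather: (B, 365, 250), vit: (B, 365, 384), q_hist: (B, 180, 1)
    """
    batch = weather.shape[0]
    # Encoder input: weather features followed by ViT features.
    x_enc = np.concatenate([weather, vit], axis=-1).astype(np.float32)
    # Time marks are unused for daily data, so zeros are passed.
    x_mark_enc = np.zeros((batch, SEQ_LEN, 4), dtype=np.float32)
    # Decoder seed: 180 days of history followed by 30 zero placeholders.
    zeros = np.zeros((batch, PRED_LEN, 1), dtype=np.float32)
    x_dec = np.concatenate([q_hist.astype(np.float32), zeros], axis=1)
    # Assumption: decoder time marks are zeros as well.
    x_mark_dec = np.zeros((batch, LABEL_LEN + PRED_LEN, 4), dtype=np.float32)
    return x_enc, x_mark_enc, x_dec, x_mark_dec


rng = np.random.default_rng(0)
x_enc, x_mark_enc, x_dec, x_mark_dec = build_inputs(
    rng.standard_normal((2, SEQ_LEN, WEATHER_DIM)),
    rng.standard_normal((2, SEQ_LEN, VIT_DIM)),
    rng.standard_normal((2, LABEL_LEN, 1)),
)
print(x_enc.shape, x_dec.shape)
```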

Architecture

x_enc ─ split ┬─ weather (250) ──────────────────────────────┐
              │                                              ├─ concat ─► Transformer (TADA) ─► pred
              └─ ViT (384) ─► compressor (384→128→32) ─► * ──┘
                                                         ▲
                 gate(weather | compressed | staleness) ─┘
  • Transformer backbone: 3 enc / 2 dec layers, 8 heads, d_model=d_ff=256, dropout=0.2
  • Compressor: 2-layer MLP with GELU (384 β†’ 128 β†’ 32)
  • Gate: 2-layer MLP with sigmoid output (283 β†’ 64 β†’ 32); conditions on weather + compressed ViT + staleness
  • Staleness signal: per-timestep count of consecutive unchanged ViT tokens, divided by 30
  • 5.36M parameters total, 21 MB on disk (only gate + compressor are trainable once the baseline is loaded)
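The bullets above can be read as the following forward pass through the fusion stage. This is a NumPy sketch with random weights, not the code in model_code/. Several details are assumptions made for illustration: the gate's hidden activation (GELU here), omitting biases, applying no activation after the compressor's second layer, computing staleness on the raw 384-dim tokens, and all function and variable names.

```python
import numpy as np

WEATHER_DIM, VIT_DIM, HIDDEN, COMPRESSED, GATE_HIDDEN = 250, 384, 128, 32, 64


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def staleness(vit, scale=30.0):
    """Count consecutive unchanged tokens per timestep, divided by `scale`.

    vit: (T, 384) -> (T, 1)
    """
    counts = np.zeros(vit.shape[0], dtype=np.float32)
    for t in range(1, vit.shape[0]):
        unchanged = np.array_equal(vit[t], vit[t - 1])
        counts[t] = counts[t - 1] + 1.0 if unchanged else 0.0
    return (counts / scale)[:, None]


def gated_fusion(x_enc, params):
    """x_enc: (T, 634) -> fused features (T, 282) in this sketch."""
    weather, vit = x_enc[:, :WEATHER_DIM], x_enc[:, WEATHER_DIM:]
    # Compressor: 384 -> 128 -> 32
    compressed = gelu(vit @ params["c1"]) @ params["c2"]
    # Gate input: 250 weather + 32 compressed + 1 staleness = 283
    gate_in = np.concatenate([weather, compressed, staleness(vit)], axis=-1)
    gate = sigmoid(gelu(gate_in @ params["g1"]) @ params["g2"])  # (T, 32), values in (0, 1)
    # The gate scales the compressed ViT features before the concat.
    return np.concatenate([weather, compressed * gate], axis=-1)


rng = np.random.default_rng(0)
params = {
    "c1": rng.standard_normal((VIT_DIM, HIDDEN)) * 0.05,
    "c2": rng.standard_normal((HIDDEN, COMPRESSED)) * 0.05,
    "g1": rng.standard_normal((WEATHER_DIM + COMPRESSED + 1, GATE_HIDDEN)) * 0.05,
    "g2": rng.standard_normal((GATE_HIDDEN, COMPRESSED)) * 0.05,
}
x_enc = rng.standard_normal((365, WEATHER_DIM + VIT_DIM))
x_enc[10:20, WEATHER_DIM:] = x_enc[10, WEATHER_DIM:].copy()  # a 10-day stale stretch
fused = gated_fusion(x_enc, params)
```

Because the gate output lies in (0, 1) per compressed channel, a long stale stretch can push the ViT contribution toward zero while the weather features pass through unchanged.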

Training

  1. Warm-start the inner Transformer from the TNFR 40-year baseline (seed 4).
  2. Freeze the Transformer; train only the compressor and gate.
  3. Loss: Masked Huber (δ=1.0); AdamW LR=1e-3, dropout 0.2, 4 epochs, OneCycle.
  4. 5 seeds; this checkpoint is seed 1.
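The recipe above can be sketched in PyTorch as follows. This is a hedged illustration, not the training script used for this checkpoint. The parameter-name prefixes `compressor` and `gate`, the toy stand-in module, treating NaN targets as the masked entries, using 1e-3 as the OneCycle peak rate, and the step count are all assumptions.

```python
import torch
from torch import nn
import torch.nn.functional as F


def freeze_all_but(model: nn.Module, prefixes=("compressor", "gate")):
    """Leave only parameters whose names start with `prefixes` trainable."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(prefixes)
    return [p for p in model.parameters() if p.requires_grad]


def masked_huber(pred, target, delta=1.0):
    """Huber loss averaged over targets that are not NaN (assumed mask)."""
    mask = ~torch.isnan(target)
    return F.huber_loss(pred[mask], target[mask], delta=delta)


class Toy(nn.Module):
    # Stand-in with the three parts named in the steps above.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 1)
        self.compressor = nn.Linear(4, 2)
        self.gate = nn.Linear(3, 2)


model = Toy()
trainable = freeze_all_but(model)
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
# total_steps is a placeholder; use epochs * batches_per_epoch in practice.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=100)

loss = masked_huber(
    torch.tensor([0.0, 3.0, 1.0]), torch.tensor([0.5, float("nan"), 3.0])
)
```

Passing only the trainable parameters to the optimizer keeps the frozen backbone out of the optimizer state.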

Performance

Strongest single model on natural / monsoon-dominated Arizona basins in the TADA study; see the paper for median NSE/KGE tables by regulation tier and season.

Normalization

Inputs to this model must be z-score normalized with the frozen training statistics, not statistics computed from your own data. The stats for ViT-G include the 384 ViT-DINOv2 feature dimensions (so this file is ~13 MB, larger than the base variant). Shipped as locked_stats.json.

import json
import numpy as np

with open("locked_stats.json") as f:
    stats = json.load(f)
site  = "09379025"
means = stats["basin_dynamic_stats"][site]["means"]       # includes vit_0 … vit_383
stds  = stats["basin_dynamic_stats"][site]["stds"]

# apply to your own daily dataframe `df` with columns matching feature names
for col in means:
    df[col] = (df[col] - means[col]) / (stds[col] + 1e-6)

# un-normalize streamflow output; `y_pred` is the forecast tensor (out[0])
mean_q, std_q = means["streamflow"], stds["streamflow"]
# detach/cpu so the conversion also works for GPU or grad-tracking tensors
y_real = np.maximum(y_pred.detach().cpu().numpy() * std_q + mean_q, 0.0)
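To make the round trip concrete, here is a self-contained sketch with a toy stats dictionary. The numbers are invented for illustration and are not taken from locked_stats.json, and the helper names are hypothetical. It also shows why forecasts are clipped at zero after un-normalizing: a strongly negative z-score would otherwise map to a negative flow.

```python
EPS = 1e-6  # same small constant as in the snippet above

# Toy stand-ins for the means / stds of one basin (invented values).
means = {"streamflow": 12.0, "precip": 1.5}
stds = {"streamflow": 4.0, "precip": 3.0}


def normalize(value, name):
    """z-score a raw value with the frozen statistics."""
    return (value - means[name]) / (stds[name] + EPS)


def denormalize_flow(z):
    """Map a normalized forecast back to flow units, clipped at zero."""
    return max(z * stds["streamflow"] + means["streamflow"], 0.0)


z = normalize(20.0, "streamflow")       # roughly 2.0
print(round(denormalize_flow(z), 3))    # 20.0
print(denormalize_flow(-5.0))           # 0.0
```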

Top-level keys in locked_stats.json:

  • basin_dynamic_stats: per-basin means/stds for 424 features (40 met/hydro + 384 ViT dims)
  • global_static_stats: static catchment-attribute means/stds
  • static_data_lookup: per-basin raw static attribute values

The ViT dimensions in basin_dynamic_stats correspond to DINOv2 ViT-small-14 CLS embeddings of catchment-centered Sentinel-2 tiles.

Files

  • model.safetensors: weights (state_dict)
  • config.json: constructor kwargs (including weather_dim, vit_dim, compressed_dim)
  • locked_stats.json: frozen per-basin normalization statistics (315 Arizona basins, 424 dynamic features)
  • model_code/: self-contained PyTorch source
  • inference.py: end-to-end demo

License

MIT.

Related repos

  • kdahal/tnfr-streamflow-arizona: time-series-only baseline
  • kdahal/cnn-streamflow-arizona: CNN hybrid (raster tiles instead of ViT embeddings)