<div align="center">

# ConvGRU-Ensemble

**Ensemble precipitation nowcasting using Convolutional GRU networks**

*Pretrained model for Italy:* ***IRENE*** (**I**talian **R**adar **E**nsemble **N**owcasting **E**xperiment)

[CI](https://github.com/DSIP-FBK/ConvGRU-Ensemble/actions) · [License: BSD 2-Clause](LICENSE) · [Python 3.13+](https://python.org) · [Model: it4lia/irene on Hugging Face](https://huggingface.co/it4lia/irene)

<br>

<table align="center" style="border: none; border-collapse: collapse;">
<tr style="border: none;">
<td colspan="2" align="center" style="border: none; padding: 10px;"><a href="https://it4lia-aifactory.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/05/logo-IT4LIA-AI-factory.svg" width="200" alt="IT4LIA AI-Factory"></a></td>
</tr>
<tr style="border: none;">
<td align="center" style="border: none; padding: 10px;"><a href="https://www.fbk.eu"><img src="https://webvalley.fbk.eu/static/img/logos/fbk-logo-blue.png" width="120" alt="Fondazione Bruno Kessler"></a></td>
<td align="center" style="border: none; padding: 10px;"><a href="https://www.italiameteo.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/08/logo-italiameteo.svg" width="180" alt="ItaliaMeteo"></a></td>
</tr>
</table>

<br>

The model encodes past radar frames into multi-scale hidden states and decodes them into an **ensemble of probabilistic forecasts** by running the decoder several times with different noise inputs. It is trained with the **CRPS loss** (Continuous Ranked Probability Score).

</div>

---

## Quick Start

<details open>
<summary><b>Load from HuggingFace Hub</b></summary>

```python
import numpy as np

from convgru_ensemble import RadarLightningModel

model = RadarLightningModel.from_pretrained("it4lia/irene")

past = np.load("past_radar.npy")  # rain rate in mm/h, shape (T_past, H, W)
forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
# forecasts.shape = (10, 12, H, W): 10 members, 12 future steps, mm/h
```

</details>
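
An ensemble is usually summarized before it is used. The sketch below assumes `forecasts` is a NumPy array laid out like the output of `model.predict` in the example above, `(members, steps, H, W)` in mm/h, with no NaN values. The helper name `summarize_ensemble` and the 1 mm/h threshold are illustrative choices, not part of the package.

```python
import numpy as np


def summarize_ensemble(forecasts: np.ndarray, threshold: float = 1.0) -> dict:
    """Summarize an ensemble of shape (members, steps, H, W) in mm/h.

    Hypothetical helper: returns the ensemble mean, the ensemble spread
    (population standard deviation) and the per-pixel probability of
    exceeding `threshold`, each with shape (steps, H, W).
    """
    if forecasts.ndim != 4:
        raise ValueError("expected shape (members, steps, H, W)")
    mean = forecasts.mean(axis=0)
    spread = forecasts.std(axis=0)
    # Fraction of members above the threshold at each pixel and lead time
    prob_exceed = (forecasts > threshold).mean(axis=0)
    return {"mean": mean, "spread": spread, "prob_exceed": prob_exceed}


# Random stand-in for model output: 10 members, 12 steps, 64 x 64 pixels
rng = np.random.default_rng(0)
fake = rng.gamma(shape=0.5, scale=2.0, size=(10, 12, 64, 64))
summary = summarize_ensemble(fake, threshold=1.0)
print(summary["prob_exceed"].shape)  # (12, 64, 64)
```

With 10 members the exceedance probability takes values in steps of 0.1; larger ensembles give a smoother probability field.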

<details>
<summary><b>CLI Inference</b></summary>

```bash
convgru-ensemble predict \
  --input examples/sample_data.nc \
  --hub-repo it4lia/irene \
  --forecast-steps 12 \
  --ensemble-size 10 \
  --output predictions.nc
```

</details>

<details>
<summary><b>Serve via API</b></summary>

```bash
# With Docker
docker compose up

# Or directly (the quotes stop shells such as zsh from expanding the brackets)
pip install "convgru-ensemble[serve]"
convgru-ensemble serve --hub-repo it4lia/irene --port 8000
```

**Submit a forecast request:**

```bash
# 20-minute forecast (4 steps × 5 min) with 5 ensemble members
curl -X POST "http://localhost:8000/predict?forecast_steps=4&ensemble_size=5" \
  -F "file=@examples/sample_data.nc" \
  -o predictions.nc

# Use the default settings (12 steps = 1 h, 10 members)
curl -X POST http://localhost:8000/predict \
  -F "file=@examples/sample_data.nc" -o predictions.nc
```

**Read the predictions:**

```python
import xarray as xr

ds = xr.open_dataset("predictions.nc")
print(ds.precipitation_forecast.shape)
# (5, 4, 1400, 1200): ensemble_member, forecast_step, y, x
```

| Endpoint | Method | Description |
|---|---|---|
| `/health` | GET | Health check |
| `/model/info` | GET | Model metadata and hyperparameters |
| `/predict` | POST | Upload NetCDF, get ensemble forecast as NetCDF |

**`/predict` query parameters:**

| Parameter | Default | Description |
|---|---|---|
| `variable` | `RR` | Name of the rain rate variable in the NetCDF |
| `forecast_steps` | `12` | Number of future 5-min steps (1–48, i.e. up to 4 h) |
| `ensemble_size` | `10` | Number of ensemble members (1–10) |

The input NetCDF must contain a 3D rain rate variable (named by `variable`) with dimensions `(T, H, W)`, values in mm/h, and at least 2 timesteps.
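
For scripted use, the `curl` calls can be reproduced from Python with only the standard library. This is a minimal sketch, assuming the server runs at `http://localhost:8000` and accepts the multipart field `file` exactly as in the `curl` examples. The helper names `predict_url` and `post_netcdf` are hypothetical, and the range checks simply mirror the limits in the table above.

```python
import os
import urllib.request
import uuid
from urllib.parse import urlencode


def predict_url(base_url: str, forecast_steps: int = 12,
                ensemble_size: int = 10, variable: str = "RR") -> str:
    """Build the /predict URL, rejecting values outside the documented ranges."""
    if not 1 <= forecast_steps <= 48:  # 5-min steps, up to 4 h
        raise ValueError("forecast_steps must be in 1..48")
    if not 1 <= ensemble_size <= 10:
        raise ValueError("ensemble_size must be in 1..10")
    query = urlencode({"variable": variable,
                       "forecast_steps": forecast_steps,
                       "ensemble_size": ensemble_size})
    return f"{base_url.rstrip('/')}/predict?{query}"


def post_netcdf(url: str, path: str) -> bytes:
    """POST a NetCDF file as multipart/form-data (field 'file'); return the response body."""
    boundary = uuid.uuid4().hex
    with open(path, "rb") as f:
        payload = f.read()
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="file"; filename="{os.path.basename(path)}"\r\n'
        "Content-Type: application/octet-stream\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    request = urllib.request.Request(
        url, data=head + payload + tail, method="POST",
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
    )
    with urllib.request.urlopen(request) as response:
        return response.read()  # NetCDF bytes on success


# Example (requires a running server):
# url = predict_url("http://localhost:8000", forecast_steps=4, ensemble_size=5)
# with open("predictions.nc", "wb") as f:
#     f.write(post_netcdf(url, "examples/sample_data.nc"))
```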

</details>

<details>
<summary><b>Fine-tune on your data</b></summary>

```bash
pip install convgru-ensemble
# See the "Training" section below
```

</details>

## Setup

Requires Python >= 3.13. Uses [uv](https://github.com/astral-sh/uv) for dependency management.

```bash
uv sync                # core dependencies
uv sync --extra serve  # + FastAPI serving
```

## Data Preparation

The training pipeline expects a Zarr dataset with a rain rate variable `RR` indexed by `(time, x, y)`.

<details>
<summary><b>1. Filter valid datacubes</b></summary>

Scan the Zarr store and list all space-time datacubes that contain fewer than `n_nan` NaN values:

```bash
cd importance_sampler
uv run python filter_nan.py path/to/dataset.zarr \
  --start_date 2021-01-01 --end_date 2025-12-11 \
  --Dt 24 --w 256 --h 256 \
  --step_T 3 --step_X 16 --step_Y 16 \
  --n_nan 10000 --n_workers 8
```

</details>

<details>
<summary><b>2. Importance sampling</b></summary>

Sample from the valid datacubes, giving rainier events a higher probability of being selected (run from the same `importance_sampler` directory):

```bash
uv run python sample_valid_datacubes.py path/to/dataset.zarr valid_datacubes_*.csv \
  --q_min 1e-4 --m 0.1 --n_workers 8
```

A pre-sampled CSV is provided in [`importance_sampler/output/`](importance_sampler/output/).
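
The flag names `--q_min` and `--m` resemble the importance-sampling rule used for radar crops in DeepMind's DGMR nowcasting model, where a crop is kept with probability q = min(1, q_min + m · mean(1 - exp(-R))) and R is the per-pixel rain rate. Whether `sample_valid_datacubes.py` uses exactly this rule is an assumption; the sketch below, with hypothetical helpers `acceptance_probability` and `sample_cubes`, only illustrates how such a rule favours rainy cubes while still keeping a small fraction of dry ones.

```python
import numpy as np


def acceptance_probability(cube: np.ndarray, q_min: float = 1e-4, m: float = 0.1) -> float:
    """Assumed DGMR-style rule: q = min(1, q_min + m * mean(1 - exp(-R))), NaNs ignored."""
    saturation = 1.0 - np.exp(-cube)  # 0 for dry pixels, approaches 1 for heavy rain
    return float(min(1.0, q_min + m * np.nanmean(saturation)))


def sample_cubes(cubes, q_min: float = 1e-4, m: float = 0.1, seed: int = 0) -> list:
    """Keep each cube independently with its acceptance probability; return kept indices."""
    rng = np.random.default_rng(seed)
    return [i for i, cube in enumerate(cubes)
            if rng.random() < acceptance_probability(cube, q_min, m)]


dry = np.zeros((24, 32, 32))
wet = np.full((24, 32, 32), 5.0)  # 5 mm/h everywhere
print(acceptance_probability(dry), acceptance_probability(wet))
```

Under this rule a completely dry cube is kept with probability `q_min`, while a cube saturated with rain approaches `q_min + m`; `m` therefore sets how strongly rainy events are oversampled relative to dry ones.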

</details>

## Training

Training is configured via [Fiddle](https://github.com/google/fiddle). Run with defaults:

```bash
uv run python -m convgru_ensemble.train
```

Override parameters from the command line:

```bash
uv run python -m convgru_ensemble.train \
  --config config:experiment \
  --config set:model.num_blocks=5 \
  --config set:model.forecast_steps=12 \
  --config set:model.loss_class=crps \
  --config set:model.ensemble_size=2 \
  --config set:datamodule.batch_size=16 \
  --config set:trainer.max_epochs=100
```

Monitor with TensorBoard: `uv run tensorboard --logdir logs/`

| Parameter | Description | Default |
|---|---|---|
| `model.num_blocks` | Encoder/decoder depth | `5` |
| `model.forecast_steps` | Future steps to predict | `12` |
| `model.ensemble_size` | Ensemble members during training | `2` |
| `model.loss_class` | Loss function (`mse`, `mae`, `crps`, `afcrps`) | `crps` |
| `model.masked_loss` | Mask NaN regions in loss | `True` |
| `datamodule.steps` | Total timesteps per sample (past + future) | `18` |
| `datamodule.batch_size` | Batch size | `16` |
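
The defaults fit together as follows: with `datamodule.steps = 18` and `model.forecast_steps = 12`, each sample leaves 18 - 12 = 6 frames as model input (30 minutes, if the training data has the 5-minute resolution used by the API). The NumPy sketch below shows that split and a masked ensemble CRPS of the kind selected by `model.loss_class=crps` with `model.masked_loss=True`, using the standard estimator CRPS = mean|x_i - y| - 0.5 · mean|x_i - x_j| over members x and observation y. The function names are hypothetical, and the PyTorch code in `losses.py` (including the `afcrps` variant, not reproduced here) may differ.

```python
import numpy as np


def split_sample(sample: np.ndarray, forecast_steps: int = 12):
    """Split a (steps, H, W) sample into (past, future) along the time axis."""
    past_steps = sample.shape[0] - forecast_steps
    if past_steps < 1:
        raise ValueError("datamodule.steps must be larger than model.forecast_steps")
    return sample[:past_steps], sample[past_steps:]


def masked_ensemble_crps(ens: np.ndarray, obs: np.ndarray) -> float:
    """Ensemble CRPS averaged over pixels where the observation is not NaN.

    ens: (M, T, H, W) ensemble forecast; obs: (T, H, W) observation.
    """
    mask = ~np.isnan(obs)                  # True where the target is valid
    x = ens[:, mask]                       # (M, N): members at valid pixels
    y = obs[mask]                          # (N,)
    skill = np.abs(x - y).mean(axis=0)     # mean absolute error of the members
    spread = np.abs(x[:, None, :] - x[None, :, :]).mean(axis=(0, 1))  # mean pairwise distance
    return float((skill - 0.5 * spread).mean())


rng = np.random.default_rng(0)
sample = rng.gamma(0.5, 2.0, size=(18, 32, 32))       # datamodule.steps = 18 frames
past, future = split_sample(sample, forecast_steps=12)
print(past.shape, future.shape)                       # (6, 32, 32) (12, 32, 32)

ensemble = rng.gamma(0.5, 2.0, size=(2, 12, 32, 32))  # model.ensemble_size = 2 members
print(masked_ensemble_crps(ensemble, future))
```

The spread term is what makes CRPS suitable for ensembles: unlike MSE on each member, it does not reward all members collapsing onto the same forecast.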

## Architecture

```
Input (B, T_past, 1, H, W)
           |
           v
+--------------------------+
|         Encoder          |  ConvGRU + PixelUnshuffle (x num_blocks)
|  Spatial dims halve at   |  Channels: 1 -> 4 -> 16 -> 64 -> 256 -> 1024
|  each block              |
+----------+---------------+
           | hidden states
           v
+--------------------------+
|         Decoder          |  ConvGRU + PixelShuffle (x num_blocks)
|  Noise input (x M runs)  |  Each run produces one ensemble member
|  for ensemble generation |
+----------+---------------+
           |
           v
Output (B, T_future, M, H, W)
```
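
The channel progression follows from `PixelUnshuffle` with a downscale factor of 2: each block folds every 2 × 2 spatial patch into 4 channels, so after k blocks a (1, H, W) frame becomes (4^k, H/2^k, W/2^k). The NumPy function below reproduces the element order of PyTorch's `torch.nn.PixelUnshuffle(2)` and traces the shapes for `num_blocks = 5`, with the ConvGRU cells omitted. It also shows that H and W must be divisible by 2^5 = 32 at the encoder input; how the package handles frames such as 1400 × 1200 (for example by padding) is not covered here.

```python
import numpy as np


def pixel_unshuffle(x: np.ndarray, r: int = 2) -> np.ndarray:
    """(C, H, W) -> (C*r*r, H/r, W/r), same element order as torch.nn.PixelUnshuffle."""
    c, h, w = x.shape
    if h % r or w % r:
        raise ValueError(f"H and W must be divisible by {r}, got {h}x{w}")
    x = x.reshape(c, h // r, r, w // r, r)  # split each spatial axis into (coarse, fine)
    x = x.transpose(0, 2, 4, 1, 3)          # move the fine offsets next to the channel axis
    return x.reshape(c * r * r, h // r, w // r)


def encoder_shapes(h: int, w: int, num_blocks: int = 5) -> list:
    """Shapes of a single-channel frame after each encoder block (ConvGRU cells omitted)."""
    x = np.zeros((1, h, w), dtype=np.float32)
    shapes = [x.shape]
    for _ in range(num_blocks):
        x = pixel_unshuffle(x)
        shapes.append(x.shape)
    return shapes


for shape in encoder_shapes(256, 256):
    print(shape)
```

For a 256 × 256 crop (the datacube size used in data preparation) this prints shapes from `(1, 256, 256)` down to `(1024, 8, 8)`, matching the channel sequence in the diagram. A raw 1400 × 1200 frame would fail at the fourth block, because 1400 / 2^3 = 175 is odd.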

## Docker

```bash
docker build -t convgru-ensemble .

# Run with a local checkpoint (older Docker versions require an absolute host path for -v)
docker run -p 8000:8000 -v "$(pwd)/checkpoints:/app/checkpoints" \
  -e MODEL_CHECKPOINT=/app/checkpoints/model.ckpt convgru-ensemble

# Run with HuggingFace Hub
docker run -p 8000:8000 -e HF_REPO_ID=it4lia/irene convgru-ensemble
```
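
The server may not accept requests immediately after the container starts, for example while a checkpoint is downloaded from the Hub. A small standard-library helper can poll the `/health` endpoint before any `/predict` call. It assumes only that `/health` answers with HTTP 200 once the server is ready; the name `wait_until_healthy` and the timeout values are illustrative.

```python
import time
import urllib.error
import urllib.request


def wait_until_healthy(base_url: str = "http://localhost:8000",
                       timeout: float = 120.0, interval: float = 2.0) -> bool:
    """Poll GET {base_url}/health until it returns HTTP 200 or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    url = base_url.rstrip("/") + "/health"
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=interval) as response:
                if response.status == 200:
                    return True
        except (urllib.error.URLError, ConnectionError, TimeoutError):
            pass  # not ready yet: connection refused/reset, timeout or non-2xx status
        time.sleep(interval)
    return False


# Example:
# if not wait_until_healthy():
#     raise SystemExit("API did not become healthy in time")
```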

## Project Structure

```
ConvGRU-Ensemble/
+-- convgru_ensemble/       # Python package
|   +-- model.py            # ConvGRU encoder-decoder architecture
|   +-- losses.py           # CRPS, afCRPS, masked loss wrappers
|   +-- lightning_model.py  # PyTorch Lightning training module
|   +-- datamodule.py       # Dataset and data loading
|   +-- train.py            # Training entry point (Fiddle config)
|   +-- utils.py            # Rain rate <-> reflectivity conversions
|   +-- hub.py              # HuggingFace Hub upload/download
|   +-- cli.py              # CLI for inference and serving
|   +-- serve.py            # FastAPI inference server
+-- examples/               # Sample data for testing
+-- importance_sampler/     # Data preparation scripts
+-- notebooks/              # Example notebooks
+-- scripts/                # Utility scripts (e.g., upload to Hub)
+-- tests/                  # Test suite
+-- Dockerfile              # Container image for the serving API
+-- MODEL_CARD.md           # HuggingFace model card template
```

## Acknowledgements

<div align="center">

Developed at **Fondazione Bruno Kessler (FBK)**, Trento, Italy, as part of the **Italian AI-Factory (IT4LIA)**, an EU-funded initiative supporting AI adoption across SMEs, academia, and public/private sectors. This work showcases capabilities in the **Earth (weather and climate) vertical domain**.

<br>

<table align="center" style="border: none; border-collapse: collapse;">
<tr style="border: none;">
<td colspan="2" align="center" style="border: none; padding: 8px;"><a href="https://it4lia-aifactory.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/05/logo-IT4LIA-AI-factory.svg" width="170" alt="IT4LIA AI-Factory"></a></td>
</tr>
<tr style="border: none;">
<td align="center" style="border: none; padding: 8px;"><a href="https://www.fbk.eu"><img src="https://webvalley.fbk.eu/static/img/logos/fbk-logo-blue.png" width="100" alt="Fondazione Bruno Kessler"></a></td>
<td align="center" style="border: none; padding: 8px;"><a href="https://www.italiameteo.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/08/logo-italiameteo.svg" width="150" alt="ItaliaMeteo"></a></td>
</tr>
</table>

</div>

## License

BSD 2-Clause License. See [LICENSE](LICENSE).