Upload source code and documentation
*This view is limited to 50 files because the commit contains too many changes.*
- .gitattributes +2 -0
- README.md +1213 -0
- data/download/__init__.py +6 -0
- data/download/ghcn_daily.py +517 -0
- data/download/ghcn_hourly.py +465 -0
- data/loaders/__init__.py +6 -0
- data/loaders/forecast_dataset.py +367 -0
- data/loaders/station_dataset.py +380 -0
- data/processed/ghcn_combined.parquet +3 -0
- data/processed/training/X.npy +3 -0
- data/processed/training/Y.npy +3 -0
- data/processed/training/meta.npy +3 -0
- data/processed/training/stats.npz +3 -0
- data/processing/__init__.py +6 -0
- data/processing/ghcn_processor.py +319 -0
- data/processing/pipeline.py +469 -0
- data/processing/quality_control.py +404 -0
- data/raw/ghcn_daily/ghcnd-inventory.txt +3 -0
- data/raw/ghcn_daily/ghcnd-stations.txt +3 -0
- data/raw/ghcn_daily/stations/USC00010063.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010148.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010160.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010163.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010178.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010252.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010260.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010267.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010369.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010377.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010390.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010395.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010402.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010407.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010422.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010425.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010430.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010505.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010583.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010616.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010655.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010757.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010764.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010823.dly +0 -0
- data/raw/ghcn_daily/stations/USC00010836.dly +0 -0
- data/raw/ghcn_daily/stations/USC00011069.dly +0 -0
- data/raw/ghcn_daily/stations/USC00011080.dly +0 -0
- data/raw/ghcn_daily/stations/USC00011084.dly +0 -0
- data/raw/ghcn_daily/stations/USC00011099.dly +0 -0
- data/raw/ghcn_daily/stations/USC00011189.dly +0 -0
- data/raw/ghcn_daily/stations/USC00011288.dly +0 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/raw/ghcn_daily/ghcnd-inventory.txt filter=lfs diff=lfs merge=lfs -text
+data/raw/ghcn_daily/ghcnd-stations.txt filter=lfs diff=lfs merge=lfs -text
```
README.md
ADDED
@@ -0,0 +1,1213 @@
---
language: en
tags:
- weather
- time-series
- pytorch
- climate
license: apache-2.0
model-index:
- name: LILITH
  results: []
---

# L.I.L.I.T.H. (Long-range Intelligent Learning for Integrated Trend Hindcasting)

**A lightweight, open-source weather prediction model trained on GHCN data.**

<p align="center">
  <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="Python 3.10+">
  <img src="https://img.shields.io/badge/PyTorch-2.1+-ee4c2c.svg" alt="PyTorch">
  <img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License">
</p>

## Model Description

LILITH is a transformer-based weather forecasting model designed to run on consumer hardware (e.g., an RTX 3060). It learns from 150+ years of station-based observations (GHCN-Daily) to predict 90-day temperature and precipitation trends with uncertainty quantification.

<p align="center">
  <a href="#why-lilith">Why LILITH</a> •
  <a href="#features">Features</a> •
  <a href="#quick-start">Quick Start</a> •
  <a href="#architecture">Architecture</a> •
  <a href="#contributing">Contributing</a>
</p>

---

## The Weather Belongs to Everyone

Corporations charge billions of dollars for weather forecasts built on **freely available public data**. The Global Historical Climatology Network (GHCN)—maintained by NOAA with taxpayer funding—contains over **150 years** of weather observations from **100,000+ stations worldwide**. This data is public domain. It belongs to humanity.

Yet somehow, we've accepted that accurate long-range forecasting should be locked behind enterprise paywalls and proprietary black boxes.

**LILITH exists to change that.**

With a single consumer GPU (an RTX 3060 with 12 GB of VRAM), you can train and run a weather prediction model that delivers **90-day forecasts** with uncertainty quantification—the same capabilities that corporations charge premium prices for. No cloud subscriptions. No API limits. No black boxes.

```
┌────────────────────────────────────────────────────────────────────────────┐
│                                                                            │
│   "The same public data that corporations use to train billion-dollar     │
│    weather systems is available to anyone with a GPU and curiosity."      │
│                                                                            │
└────────────────────────────────────────────────────────────────────────────┘
```

### The Data is Free. The Science is Open. The Code is Yours

| What Corporations Charge For | What LILITH Provides Free |
|------------------------------|---------------------------|
| 90-day extended forecasts | 90-day forecasts with uncertainty bands |
| "Proprietary" ML models | Fully transparent architecture |
| Enterprise API access | Self-hosted, unlimited queries |
| Historical climate analytics | 150+ years of GHCN data access |
| Per-query pricing | Run on your own hardware |

---

## Why LILITH

### The Problem

Modern weather AI (GraphCast, Pangu-Weather, FourCastNet) achieves remarkable accuracy, but:

- **Requires ERA5 reanalysis data** — computationally expensive to generate, controlled by ECMWF
- **Needs massive compute** — training requires hundreds of TPUs/GPUs
- **Inference is heavy** — full global models need 80GB+ VRAM
- **Closed ecosystems** — weights available, but practical deployment requires significant resources

### The Solution

LILITH takes a different approach:

1. **Station-Native Architecture** — Learns directly from sparse GHCN station observations instead of requiring gridded reanalysis
2. **Hierarchical Processing** — Graph attention for spatial relationships, spectral methods for global dynamics
3. **Memory Efficient** — Gradient checkpointing, INT8/INT4 quantization, runs on consumer GPUs
4. **Truly Open** — Apache 2.0 license, reproducible training, no hidden dependencies

---

## Features

### Core Capabilities

- **90-Day Forecasts** — Extended-range predictions competitive with commercial services
- **Uncertainty Quantification** — Know not just the prediction, but how confident it is
- **150+ Years of Data** — Built on the complete GHCN historical record
- **Global Coverage** — Forecasts for any location on Earth
- **Multiple Variables** — Temperature, precipitation, wind, pressure, humidity

### Technical Highlights

- **Consumer Hardware** — Inference on an RTX 3060 (12 GB), training on an RTX 4090 or multiple GPUs
- **Horizontally Scalable** — From laptop to cluster with Ray Serve
- **Modern Stack** — PyTorch 2.x, Flash Attention, DeepSpeed, FastAPI, Next.js 14
- **Production Ready** — Docker containers, Redis caching, PostgreSQL + TimescaleDB

### User Experience

- **Glassmorphic UI** — Beautiful, modern interface with dynamic weather backgrounds
- **Interactive Maps** — Mapbox GL JS with temperature layers and station markers
- **Rich Visualizations** — Recharts/D3 for forecasts, uncertainty bands, wind roses
- **Historical Explorer** — Analyze 150+ years of climate trends

---

## Quick Start

### Prerequisites

- Python 3.10+
- CUDA-capable GPU (12GB+ VRAM recommended)
- Node.js 18+ (for frontend)

### Quick Start with Pre-trained Model

If you have a trained checkpoint (e.g., `lilith_best.pt`), you can run the full stack immediately:

```bash
# 1. Clone and setup
git clone https://github.com/consigcody94/lilith.git
cd lilith
python -m venv .venv
.venv\Scripts\activate        # Windows
# source .venv/bin/activate   # Linux/Mac

# 2. Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -e ".[all]"

# 3. Place your checkpoint in the checkpoints folder
mkdir checkpoints
# Copy lilith_best.pt to checkpoints/

# 4. Set OpenWeatherMap API key (optional but recommended for live data)
export OPENWEATHER_API_KEY="your_api_key_here"   # Linux/Mac
# set OPENWEATHER_API_KEY=your_api_key_here      # Windows

# 5. Start the API server (auto-detects checkpoint)
python -m uvicorn web.api.main:app --host 127.0.0.1 --port 8000

# 6. In a new terminal, start the frontend
cd web/frontend
npm install
npm run dev

# 7. Open http://localhost:3000 in your browser
```

The API will automatically find and load `checkpoints/lilith_best.pt` or `checkpoints/lilith_final.pt`. You'll see log output like:

```
Found checkpoint at C:\...\checkpoints\lilith_best.pt
Model loaded on cuda
Config: d_model=128, layers=4
Val RMSE: 3.96°C
Model loaded successfully (RMSE: 3.96°C)
```

**Test the API directly:**

```bash
curl -X POST http://127.0.0.1:8000/v1/forecast \
  -H "Content-Type: application/json" \
  -d '{"latitude": 40.7128, "longitude": -74.006, "days": 14}'
```

### Installation

```bash
# Clone the repository
git clone https://github.com/consigcody94/lilith.git
cd lilith

# Create and activate virtual environment
python -m venv .venv
source .venv/bin/activate   # Linux/Mac
# .venv\Scripts\activate    # Windows

# Install with all dependencies
pip install -e ".[all]"
```

### Download Data

```bash
# Download GHCN-Daily station data
python scripts/download_data.py --source ghcn-daily --stations 5000 --years 50

# Process and prepare for training
python scripts/process_data.py --config configs/data/default.yaml
```

### Training

LILITH training is designed to work on consumer GPUs. Here's a complete step-by-step guide:

#### Step 1: Environment Setup

```bash
# Create and activate virtual environment
python -m venv .venv
.venv\Scripts\activate        # Windows
# source .venv/bin/activate   # Linux/Mac

# Install PyTorch with CUDA support
# For RTX 30/40 series:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# For RTX 50 series (Blackwell - requires nightly):
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128

# Install LILITH dependencies
pip install -e ".[all]"
```
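
Before moving on, it can help to confirm that PyTorch actually sees your GPU. This quick check uses standard PyTorch calls (not a LILITH utility) and catches most driver/CUDA mismatches:

```python
import torch

print(torch.__version__)                  # installed PyTorch version
print(torch.cuda.is_available())          # should print True on a working CUDA setup
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 3060"
```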

#### Step 2: Download Training Data

```bash
# Download GHCN station data (start with 300 stations for quick training)
python -m data.download.ghcn_daily \
    --stations 300 \
    --min-years 30 \
    --country US

# For better models, download more stations
python -m data.download.ghcn_daily \
    --stations 5000 \
    --min-years 20 \
    --elements TMAX,TMIN,PRCP

# Download climate indices for long-range prediction
python -m data.download.climate_indices --all
```

#### Step 3: Process Data

```bash
# Process raw GHCN data into training format
python -m data.processing.ghcn_processor

# This creates:
# - data/processed/ghcn_combined.parquet (all station data)
# - data/processed/training/X.npy (input sequences)
# - data/processed/training/Y.npy (target sequences)
# - data/processed/training/meta.npy (station metadata)
# - data/processed/training/stats.npz (normalization stats)
```
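
A quick sanity check on the processed arrays, using plain NumPy (the exact shapes depend on how many stations you downloaded and your sequence settings):

```python
import numpy as np

X = np.load("data/processed/training/X.npy")        # input sequences
Y = np.load("data/processed/training/Y.npy")        # target sequences
meta = np.load("data/processed/training/meta.npy")  # station metadata
stats = np.load("data/processed/training/stats.npz")

print(X.shape, Y.shape, meta.shape)  # shapes depend on your pipeline settings
print(list(stats.keys()))            # normalization statistics saved by the processor
```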

#### Step 4: Train the Model

```bash
# Quick training (30 epochs, good for testing)
python -m training.train_simple \
    --epochs 30 \
    --batch-size 64 \
    --d-model 128 \
    --layers 4

# Full training (100 epochs, production quality)
python -m training.train_simple \
    --epochs 100 \
    --batch-size 128 \
    --d-model 256 \
    --layers 6 \
    --lr 1e-4

# Resume training from checkpoint
python -m training.train_simple \
    --resume checkpoints/lilith_best.pt \
    --epochs 50
```

#### Step 5: Monitor Training

During training, you'll see output like:

```
Epoch 1/30 | Train Loss: 0.8234 | Val Loss: 0.7891 | Temp RMSE: 4.21°C | Temp MAE: 3.15°C
Epoch 2/30 | Train Loss: 0.6543 | Val Loss: 0.6234 | Temp RMSE: 3.45°C | Temp MAE: 2.67°C
...
Epoch 30/30 | Train Loss: 0.2134 | Val Loss: 0.2456 | Temp RMSE: 1.89°C | Temp MAE: 1.42°C
```

Target metrics:

- **Days 1-7**: Temp RMSE < 2°C
- **Days 8-14**: Temp RMSE < 3°C

#### Step 6: Use the Trained Model

```bash
# Update the API to use your trained model
# Edit web/api/main.py and set DEMO_MODE = False

# Or run inference directly
python -m inference.forecast \
    --checkpoint checkpoints/lilith_best.pt \
    --lat 40.7128 --lon -74.006 \
    --days 90
```

#### Training on Multiple GPUs

```bash
# Using PyTorch DistributedDataParallel
torchrun --nproc_per_node=2 training/train_distributed.py \
    --config models/configs/large.yaml

# Using DeepSpeed for memory efficiency
deepspeed --num_gpus=4 training/train_deepspeed.py \
    --config models/configs/xl.yaml \
    --deepspeed configs/training/ds_config.json
```

#### Memory Requirements

| Model Size | Batch Size | VRAM Required |
|------------|------------|---------------|
| d_model=128 | 64 | ~4 GB |
| d_model=256 | 64 | ~8 GB |
| d_model=256 | 128 | ~12 GB |
| d_model=512 | 64 | ~16 GB |

#### Training Tips

1. **Start small**: Train with 300 stations first to verify everything works
2. **Monitor GPU usage**: Use `nvidia-smi` to ensure the GPU is being utilized
3. **Watch for overfitting**: If val loss rises while train loss keeps falling, stop training earlier or reduce the number of epochs
4. **Save checkpoints**: The best model is automatically saved to `checkpoints/lilith_best.pt`
5. **Use mixed precision**: Enabled by default (FP16), cuts memory usage in half (see the sketch below)
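
For reference, mixed-precision training in PyTorch typically looks like the following. This is a self-contained toy sketch using the standard `torch.cuda.amp` APIs, not LILITH's actual training loop (the real model and data pipeline live in `training.train_simple`):

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

# Toy stand-ins for the real model, optimizer, and batch.
model = nn.Linear(3, 3).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
scaler = GradScaler()

X = torch.randn(64, 3, device="cuda")
Y = torch.randn(64, 3, device="cuda")

optimizer.zero_grad()
with autocast():                  # forward pass runs in reduced precision
    loss = criterion(model(X), Y)
scaler.scale(loss).backward()     # scale the loss to avoid FP16 gradient underflow
scaler.step(optimizer)            # unscales gradients, then takes the optimizer step
scaler.update()                   # adjusts the scale factor for the next iteration
```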

---

## Pre-trained Models

### Using Pre-trained Checkpoints

Once a model is trained, you **do not need to retrain** — the checkpoint file contains everything needed for inference. Anyone can download and use pre-trained models.

#### Checkpoint File Contents

The `.pt` checkpoint file (~20-50MB depending on model size) contains:

```python
checkpoint = {
    'epoch': 20,                      # Training epoch when saved
    'model_state_dict': {...},        # All learned weights
    'optimizer_state_dict': {...},    # Optimizer state (for resuming training)
    'val_loss': 0.2456,               # Validation loss at checkpoint
    'val_rmse': 1.89,                 # Temperature RMSE in °C
    'config': {                       # Model architecture config
        'input_features': 3,
        'output_features': 3,
        'd_model': 128,
        'nhead': 4,
        'num_encoder_layers': 4,
        'num_decoder_layers': 4,
        'dropout': 0.1
    },
    'normalization': {                # Data normalization stats
        'X_mean': [...],
        'X_std': [...],
        'Y_mean': [...],
        'Y_std': [...]
    }
}
```
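
To peek inside a checkpoint before committing to it, a plain `torch.load` is enough (this uses only the keys documented above):

```python
import torch

ckpt = torch.load("checkpoints/lilith_best.pt", map_location="cpu")
print(ckpt["config"])    # architecture hyperparameters used at training time
print(ckpt["val_rmse"])  # validation temperature RMSE in °C, e.g. 3.96
print(ckpt["epoch"])     # epoch at which this checkpoint was saved
```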

#### Pre-trained Checkpoint Included

A pre-trained checkpoint (`lilith_best.pt`) is included in the `checkpoints/` folder. This model was trained on:

- **915,000 sequences** from 300 US GHCN stations
- **20 epochs** of training
- **Validation RMSE: 3.96°C**

You can use this checkpoint immediately or train your own model with different data/parameters.

#### Model Specifications

| Model | Parameters | File Size | VRAM (Inference) | Best For |
|-------|------------|-----------|------------------|----------|
| **SimpleLILITH** | 1.87M | ~23 MB | 2-4 GB | Default model, fast training |
| **lilith-base** | 150M | ~45 MB | 4 GB | Balanced accuracy/speed |
| **lilith-large** | 400M | ~120 MB | 8 GB | High accuracy |

### GPU Requirements for Inference

Unlike training, inference requires much less VRAM. Here's what you can run on different hardware:

| GPU | VRAM | Models Supported | Batch Size | Latency (90-day forecast) |
|-----|------|------------------|------------|---------------------------|
| **RTX 3050/4050** | 4 GB | Tiny, Base (INT8) | 1 | ~1.5 sec |
| **RTX 3060/4060** | 8 GB | Tiny, Base, Large (INT8) | 1-4 | ~0.8 sec |
| **RTX 3070/4070** | 8-12 GB | All models (FP16) | 4-8 | ~0.5 sec |
| **RTX 3080/4080** | 10-16 GB | All models (FP16) | 8-16 | ~0.3 sec |
| **RTX 3090/4090** | 24 GB | All models, ensembles | 32+ | ~0.2 sec |
| **RTX 5050** | 8.5 GB | Tiny, Base, Large (INT8) | 1-4 | ~0.6 sec |
| **CPU Only** | N/A | All models (slow) | 1 | ~10-30 sec |

#### Quantization for Smaller GPUs

```bash
# Convert to INT8 for 50% memory reduction
python -m inference.quantize \
    --checkpoint checkpoints/lilith_base.pt \
    --output checkpoints/lilith_base_int8.pt \
    --precision int8

# Convert to INT4 for 75% memory reduction (slight accuracy loss)
python -m inference.quantize \
    --checkpoint checkpoints/lilith_base.pt \
    --output checkpoints/lilith_base_int4.pt \
    --precision int4
```
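
`inference.quantize` is LILITH's own wrapper. For context, here's a minimal sketch of the underlying idea using stock PyTorch dynamic quantization, which converts `nn.Linear` weights to INT8 for CPU inference (the model below is a toy stand-in, not the actual LILITH class):

```python
import torch
from torch import nn

# Toy stack standing in for a trained LILITH model.
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 3))
model.eval()

# Replace Linear layers with dynamically quantized INT8 versions.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

with torch.no_grad():
    out = qmodel(torch.randn(1, 128))  # runs on CPU with INT8 weights
```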

### Loading and Using a Checkpoint

#### Python API

```python
import torch
from models.lilith import SimpleLILITH

# Load checkpoint
checkpoint = torch.load('checkpoints/lilith_best.pt', map_location='cuda')

# Recreate model from config
model = SimpleLILITH(**checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Get normalization stats
X_mean = torch.tensor(checkpoint['normalization']['X_mean'])
X_std = torch.tensor(checkpoint['normalization']['X_std'])
Y_mean = torch.tensor(checkpoint['normalization']['Y_mean'])
Y_std = torch.tensor(checkpoint['normalization']['Y_std'])

# Run inference. X (an input sequence tensor) and meta (station metadata)
# are assumed to be prepared the same way as during training.
with torch.no_grad():
    # Normalize input
    X_norm = (X - X_mean) / X_std

    # Predict
    pred = model(X_norm, meta, target_len=14)

    # Denormalize output
    pred_denorm = pred * Y_std + Y_mean
```

#### Command Line

```bash
# Single location forecast
python -m inference.forecast \
    --checkpoint checkpoints/lilith_best.pt \
    --lat 40.7128 --lon -74.006 \
    --days 90 \
    --output forecast.json

# Batch inference for multiple locations
python -m inference.forecast \
    --checkpoint checkpoints/lilith_best.pt \
    --locations-file locations.csv \
    --days 90 \
    --output forecasts/
```

#### Start API Server with Trained Model

```bash
# Set checkpoint path
export LILITH_CHECKPOINT=checkpoints/lilith_best.pt

# Start API (will use trained model instead of demo mode)
python -m web.api.main

# Or specify directly
python -m uvicorn web.api.main:app --host 0.0.0.0 --port 8000
```

### Sharing Your Trained Model

#### Upload to HuggingFace Hub

```python
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="checkpoints/lilith_best.pt",
    path_in_repo="lilith_base_v1.pt",
    repo_id="your-username/lilith-base",
    repo_type="model"
)
```
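
The counterpart for consumers of your model, downloading the checkpoint from the Hub, uses `hf_hub_download`; the repo and filename here mirror the placeholder values in the upload example above:

```python
from huggingface_hub import hf_hub_download

# Downloads (and caches) the checkpoint, returning its local path.
path = hf_hub_download(
    repo_id="your-username/lilith-base",
    filename="lilith_base_v1.pt",
)
print(path)
```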

#### Create a GitHub Release

```bash
# Tag your release
git tag -a v1.0 -m "LILITH Base v1.0 - Trained on 915K sequences"
git push origin v1.0

# Upload checkpoint to release (via GitHub UI or gh cli)
gh release create v1.0 checkpoints/lilith_best.pt --title "LILITH v1.0"
```

### Model Training Metrics

When training completes, you'll see metrics like:

```
┌────────────────────────────────────────────────────────────────┐
│                  LILITH TRAINING COMPLETE                      │
├────────────────────────────────────────────────────────────────┤
│  Epochs:              20                                       │
│  Training Samples:    915,001                                  │
│  Final Train Loss:    0.2134                                   │
│  Final Val Loss:      0.2456                                   │
│  Temperature RMSE:    1.89°C                                   │
│  Temperature MAE:     1.42°C                                   │
│  Checkpoint:          checkpoints/lilith_best.pt (22.8 MB)     │
├────────────────────────────────────────────────────────────────┤
│  Model Config:                                                 │
│    - Parameters:      1,869,251                                │
│    - d_model:         128                                      │
│    - Attention Heads: 4                                        │
│    - Encoder Layers:  4                                        │
│    - Decoder Layers:  4                                        │
└────────────────────────────────────────────────────────────────┘
```

### Resuming Training

```bash
# Continue training from checkpoint
python -m training.train_simple \
    --resume checkpoints/lilith_best.pt \
    --epochs 50 \
    --lr 5e-5   # Lower learning rate for fine-tuning

# The checkpoint includes optimizer state, so training continues smoothly
```

### Model Comparison

| Checkpoint | Epochs | Training Data | Val RMSE | File Size | Notes |
|------------|--------|---------------|----------|-----------|-------|
| `lilith_v0.1.pt` | 10 | 100K samples | 4.3°C | 22 MB | Quick test |
| `lilith_v0.5.pt` | 30 | 500K samples | 2.8°C | 22 MB | Development |
| `lilith_v1.0.pt` | 100 | 915K samples | 1.9°C | 22 MB | Production |
| `lilith_large_v1.pt` | 100 | 2M samples | 1.5°C | 120 MB | Best accuracy |

---

### Inference

```bash
# Generate a forecast
python scripts/run_inference.py \
    --checkpoint checkpoints/best.pt \
    --lat 40.7128 --lon -74.006 \
    --days 90

# Start the API server
python scripts/start_api.py --checkpoint checkpoints/best.pt --port 8000

# Query the API
curl -X POST http://localhost:8000/v1/forecast \
  -H "Content-Type: application/json" \
  -d '{"latitude": 40.7128, "longitude": -74.006, "days": 90}'
```

### Web Interface

```bash
cd web/frontend
npm install
npm run dev
# Open http://localhost:3000
```

### Docker Deployment

```bash
# Full stack deployment
docker-compose -f docker/docker-compose.yml up -d

# Individual services
docker build -f docker/Dockerfile.inference -t lilith-inference .
docker build -f docker/Dockerfile.web -t lilith-web .
```

---

## Architecture

### Model Overview

LILITH uses a **Station-Graph Temporal Transformer (SGTT)** architecture that processes weather observations through three stages:

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                             LILITH ARCHITECTURE                             │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  INPUT: Station Observations                                                │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │ • 100,000+ GHCN stations worldwide                                  │   │
│  │ • Temperature, precipitation, pressure, wind, humidity              │   │
│  │ • Quality-controlled, gap-filled, normalized                        │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│                                   │                                         │
│                                   ▼                                         │
│  ENCODER ─────────────────────────────────────────────────────────────     │
│  ┌──────────────┐   ┌──────────────────┐   ┌────────────────────────┐      │
│  │ Station      │──▶│ Graph Attention  │──▶│ Temporal Transformer   │      │
│  │ Embedding    │   │ Network v2       │   │ (Flash Attention)      │      │
│  │              │   │                  │   │                        │      │
│  │ • 3D pos     │   │ • Spatial        │   │ • Historical context   │      │
│  │ • Features   │   │   correlations   │   │ • Causal masking       │      │
│  │ • Temporal   │   │ • Multi-hop      │   │ • RoPE embeddings      │      │
│  └──────────────┘   └──────────────────┘   └────────────────────────┘      │
│                                   │                                         │
│                                   ▼                                         │
│                  ┌───────────────────────────────┐                          │
│                  │   LATENT ATMOSPHERIC STATE    │                          │
│                  │        (64 × 128 × 256)       │                          │
│                  │                               │                          │
│                  │  Learned global grid that     │                          │
│                  │  captures atmospheric         │                          │
│                  │  dynamics implicitly          │                          │
│                  └───────────────────────────────┘                          │
│                                   │                                         │
│                                   ▼                                         │
│  PROCESSOR ───────────────────────────────────────────────────────────     │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │ Spherical Fourier Neural Operator (SFNO)                            │   │
│  │                                                                     │   │
│  │ • Operates in spectral domain on sphere                             │   │
│  │ • Captures global teleconnections (ENSO, NAO, etc.)                 │   │
│  │ • Respects Earth's spherical geometry                               │   │
│  │ • Efficient O(N log N) via spherical harmonics                      │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │ Multi-Scale Temporal Processor                                      │   │
│  │                                                                     │   │
│  │ Days 1-14:  6-hour steps   (synoptic weather)                       │   │
│  │ Days 15-42: 24-hour steps  (weekly patterns)                        │   │
│  │ Days 43-90: 168-hour steps (seasonal trends)                        │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │ Climate Embedding Module                                            │   │
│  │                                                                     │   │
│  │ • ENSO index (El Niño/La Niña state)                                │   │
│  │ • MJO phase and amplitude                                           │   │
│  │ • NAO, AO, PDO indices                                              │   │
│  │ • Seasonal cycles, solar position                                   │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│                                   │                                         │
│                                   ▼                                         │
│  DECODER ─────────────────────────────────────────────────────────────     │
│  ┌──────────────────────┐        ┌──────────────────────┐                  │
│  │ Grid Decoder         │        │ Station Decoder      │                  │
│  │                      │        │                      │                  │
│  │ • Global fields      │        │ • Point forecasts    │                  │
│  │ • Spatial upsampling │        │ • Location-specific  │                  │
│  └──────────────────────┘        └──────────────────────┘                  │
│             │                               │                               │
│             ▼                               ▼                               │
│  OUTPUT ──────────────────────────────────────────────────────────────     │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │ Ensemble Head (Optional)                                            │   │
│  │                                                                     │   │
│  │ • Diffusion-based ensemble generation                               │   │
│  │ • Gaussian, quantile, or MC dropout uncertainty                     │   │
│  │ • Calibrated confidence intervals                                   │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│                                                                             │
│  FINAL OUTPUT:                                                              │
│  • 90-day forecasts for temperature, precipitation, wind, pressure          │
│  • Uncertainty bounds (5th, 25th, 50th, 75th, 95th percentiles)             │
│  • Ensemble spread metrics                                                  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

### Model Variants

| Variant | Parameters | VRAM (FP16) | VRAM (INT8) | Best For |
|---------|------------|-------------|-------------|----------|
| **LILITH-Tiny** | 50M | 4 GB | 2 GB | Fast inference, edge deployment |
| **LILITH-Base** | 150M | 8 GB | 4 GB | Balanced accuracy/speed |
| **LILITH-Large** | 400M | 12 GB | 6 GB | High accuracy forecasts |
| **LILITH-XL** | 1B | 24 GB | 12 GB | Research, maximum accuracy |

### Key Components

| Component | Purpose | Implementation |
|-----------|---------|----------------|
| `StationEmbedding` | Encode station features + position | MLP with 3D spherical coordinates |
| `GATEncoder` | Learn spatial relationships | Graph Attention Network v2 |
| `TemporalTransformer` | Process time series | Flash Attention with RoPE |
| `SFNO` | Global atmospheric dynamics | Spherical Fourier Neural Operator |
| `ClimateEmbedding` | Encode climate indices | ENSO, MJO, NAO, seasonal |
| `EnsembleHead` | Uncertainty quantification | Diffusion / Gaussian / Quantile |
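
On the "MLP with 3D spherical coordinates" entry: mapping latitude/longitude onto the unit sphere before the embedding MLP avoids the discontinuity at ±180° longitude and the distortion near the poles that raw lat/lon features suffer from. A minimal, illustrative sketch of that coordinate transform (the actual `StationEmbedding` code lives in `models/components/station_embed.py`):

```python
import numpy as np

def latlon_to_xyz(lat_deg: float, lon_deg: float) -> np.ndarray:
    """Map a station's latitude/longitude to a point on the unit sphere.

    Nearby stations get nearby coordinates everywhere on the globe,
    including across the dateline, which raw lat/lon does not guarantee.
    """
    lat, lon = np.radians(lat_deg), np.radians(lon_deg)
    return np.array([
        np.cos(lat) * np.cos(lon),  # x
        np.cos(lat) * np.sin(lon),  # y
        np.sin(lat),                # z
    ])

print(latlon_to_xyz(40.7128, -74.006))  # New York, NY
```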

---

## Data Sources

LILITH is built entirely on **freely available public data**. The more data sources you integrate, the better your predictions will be.

### Primary: GHCN (Global Historical Climatology Network)

| Dataset | Coverage | Stations | Variables | Resolution |
|---------|----------|----------|-----------|------------|
| **GHCN-Daily** | 1763–present | 100,000+ | Temp, Precip, Snow | Daily |
| **GHCN-Hourly** | 1900s–present | 20,000+ | Wind, Pressure, Humidity | Hourly |
| **GHCN-Monthly** | 1700s–present | 26,000 | Temp, Precip | Monthly |

**Source**: [NOAA National Centers for Environmental Information](https://www.ncei.noaa.gov/products/land-based-station/global-historical-climatology-network-daily)

### Recommended Additional Data Sources

These freely available datasets can significantly improve prediction accuracy:

#### 1. ERA5 Reanalysis (Highly Recommended)

| Dataset | Coverage | Resolution | Variables |
|---------|----------|------------|-----------|
| **ERA5** | 1940–present | 0.25° / hourly | Full atmospheric state (temperature, wind, humidity, pressure at all levels) |

**Source**: [ECMWF Climate Data Store](https://cds.climate.copernicus.eu/)

- Provides gridded global data interpolated from observations
- Excellent for learning atmospheric dynamics
- ~2TB for 10 years of data at full resolution

#### 2. Climate Indices (Essential for Long-Range)

| Index | Description | Impact |
|-------|-------------|--------|
| **ENSO (ONI)** | El Niño/La Niña state | Major driver of global weather patterns |
| **NAO** | North Atlantic Oscillation | European/North American winter weather |
| **PDO** | Pacific Decadal Oscillation | Long-term Pacific climate cycles |
| **MJO** | Madden-Julian Oscillation | Tropical weather, 30-60 day cycles |
| **AO** | Arctic Oscillation | Northern Hemisphere cold outbreaks |

**Source**: [NOAA Climate Prediction Center](https://www.cpc.ncep.noaa.gov/)

```bash
# Download climate indices
python -m data.download.climate_indices --indices enso,nao,pdo,mjo,ao
```

#### 3. Sea Surface Temperature (SST)

| Dataset | Coverage | Resolution |
|---------|----------|------------|
| **NOAA OISST** | 1981–present | 0.25° / daily |
| **HadISST** | 1870–present | 1° / monthly |

**Source**: [NOAA OISST](https://www.ncei.noaa.gov/products/optimum-interpolation-sst)

- Ocean temperatures strongly influence atmospheric patterns
- Critical for predicting precipitation and temperature anomalies

#### 4. NOAA GFS Model Data

| Dataset | Forecast Range | Resolution |
|---------|----------------|------------|
| **GFS Analysis** | Historical | 0.25° / 6-hourly |
| **GFS Forecasts** | 16 days | 0.25° / hourly |

**Source**: [NOAA NOMADS](https://nomads.ncep.noaa.gov/)

- Use as an additional training signal or for ensemble weighting
- Can blend ML predictions with physics-based forecasts

#### 5. Satellite Data

| Dataset | Variables | Coverage |
|---------|-----------|----------|
| **GOES-16/17/18** | Cloud cover, precipitation | Americas |
| **NASA GPM** | Global precipitation | Global |
| **MODIS** | Land surface temperature | Global |

**Sources**:

- [NOAA CLASS](https://www.class.noaa.gov/)
- [NASA Earthdata](https://earthdata.nasa.gov/)

#### 6. Additional Reanalysis Products

| Dataset | Coverage | Best For |
|---------|----------|----------|
| **NASA MERRA-2** | 1980–present | North America |
| **NCEP/NCAR Reanalysis** | 1948–present | Historical coverage |
| **JRA-55** | 1958–present | Pacific/Asia region |

### Data Download Commands

```bash
# Download all recommended data sources
python -m data.download.all \
    --ghcn-stations 5000 \
    --era5-years 20 \
    --climate-indices all \
    --sst oisst \
    --region north_america

# Download just climate indices (small, fast)
python -m data.download.climate_indices

# Download ERA5 for a specific region (requires CDS account)
python -m data.download.era5 \
    --start-year 2000 \
    --end-year 2024 \
    --region "north_america" \
    --variables temperature,wind,humidity,pressure
```

### Data Integration Priority

For the best results, add data sources in this order:

1. **GHCN-Daily** (required) — Station observations
2. **Climate Indices** (highly recommended) — ENSO, NAO, MJO for long-range skill
3. **ERA5** (recommended) — Full atmospheric state for dynamics
4. **SST** (recommended) — Ocean influence on weather
5. **Satellite** (optional) — Real-time cloud/precip data

---

## Performance

### Accuracy Targets

| Forecast Range | Metric | LILITH Target | Climatology |
|----------------|--------|---------------|-------------|
| Days 1-7 | Temperature RMSE | < 2°C | ~5°C |
| Days 8-14 | Temperature RMSE | < 3°C | ~5°C |
| Days 15-42 | Skill Score | > 0.3 | 0.0 |
| Days 43-90 | Skill Score | > 0.1 | 0.0 |
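
The table doesn't pin down which skill score is meant; a common convention for tables like this is the MSE-based skill score relative to climatology, where 0 means "no better than climatology" and 1 is a perfect forecast:

```python
import numpy as np

def skill_score(forecast: np.ndarray, observed: np.ndarray,
                climatology: np.ndarray) -> float:
    """MSE-based skill score: 1 - MSE(forecast) / MSE(climatology)."""
    mse_forecast = np.mean((forecast - observed) ** 2)
    mse_climatology = np.mean((climatology - observed) ** 2)
    return 1.0 - mse_forecast / mse_climatology
```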

### Inference Performance (RTX 3060 12GB)

| Model | Single Location | Regional Grid | Global |
|-------|-----------------|---------------|--------|
| LILITH-Tiny (INT8) | 0.3s | 2s | 15s |
| LILITH-Base (INT8) | 0.8s | 5s | 45s |
| LILITH-Large (FP16) | 1.5s | 12s | 90s |

---

## Project Structure

```
lilith/
├── data/                           # Data pipeline
│   ├── download/                   # GHCN download scripts
│   │   ├── ghcn_daily.py           # Daily observations
│   │   └── ghcn_hourly.py          # Hourly observations
│   ├── processing/                 # Data processing
│   │   ├── quality_control.py      # Outlier detection, QC flags
│   │   ├── feature_encoder.py      # Normalization, encoding
│   │   └── gridding.py             # Station → grid interpolation
│   └── loaders/                    # PyTorch datasets
│       ├── station_dataset.py      # Station-based loading
│       └── forecast_dataset.py     # Forecast sequence loading
│
├── models/                         # Model architecture
│   ├── components/                 # Building blocks
│   │   ├── station_embed.py        # Station feature embedding
│   │   ├── gat_encoder.py          # Graph Attention Network
│   │   ├── temporal_transformer.py # Temporal processing
│   │   ├── sfno.py                 # Spherical Fourier Neural Operator
│   │   ├── climate_embed.py        # Climate indices embedding
│   │   └── ensemble_head.py        # Uncertainty quantification
│   ├── lilith.py                   # Main model class
│   ├── losses.py                   # Multi-task loss functions
│   └── configs/                    # Model configurations
│       ├── tiny.yaml
│       ├── base.yaml
│       └── large.yaml
│
├── training/                       # Training infrastructure
│   └── trainer.py                  # Training loop with DeepSpeed
│
├── inference/                      # Inference and serving
│   ├── forecast.py                 # High-level forecast API
│   └── quantize.py                 # INT8/INT4 quantization
│
├── web/
│   ├── api/                        # FastAPI backend
│   │   ├── main.py                 # Application entry point
│   │   └── schemas.py              # Pydantic models
│   └── frontend/                   # Next.js 14 frontend
│       └── src/
│           ├── app/                # App Router pages
│           ├── components/         # React components
│           └── stores/             # Zustand state
│
├── scripts/                        # CLI utilities
│   ├── download_data.py
│   ├── process_data.py
│   ├── train_model.py
│   ├── run_inference.py
│   └── start_api.py
│
├── tests/                          # Test suite
│   ├── test_models.py
│   ├── test_data.py
│   └── test_api.py
│
├── docker/                         # Containerization
│   ├── Dockerfile.inference
│   ├── Dockerfile.web
│   └── docker-compose.yml
│
└── docs/                           # Documentation
    └── architecture.md
```

---

## API Reference

### Endpoints

#### `POST /v1/forecast`

Generate a weather forecast for a location.

```json
{
  "latitude": 40.7128,
  "longitude": -74.006,
  "days": 90,
  "ensemble_members": 10,
  "variables": ["temperature", "precipitation", "wind"]
}
```

**Response:**

```json
{
  "location": {"latitude": 40.7128, "longitude": -74.006, "name": "New York, NY"},
  "generated_at": "2025-01-15T12:00:00Z",
  "model_version": "lilith-base-v1.0",
  "forecasts": [
    {
      "date": "2025-01-16",
      "temperature": {"mean": 2.5, "min": -1.2, "max": 6.8},
      "precipitation": {"probability": 0.35, "amount_mm": 2.1},
      "wind": {"speed_ms": 5.2, "direction_deg": 270},
      "uncertainty": {"temperature_std": 1.2, "confidence": 0.85}
    }
  ]
}
```
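
From Python, the same call looks like this (a minimal sketch using the `requests` library against a locally running server; field names follow the request/response examples above):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/forecast",
    json={"latitude": 40.7128, "longitude": -74.006, "days": 14},
    timeout=60,
)
resp.raise_for_status()

forecast = resp.json()
first_day = forecast["forecasts"][0]
print(first_day["date"], first_day["temperature"])
```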

#### `GET /v1/historical/{station_id}`

Retrieve historical observations for a station.

#### `GET /health`

Health check endpoint.

---

## Contributing

We welcome contributions from the community. LILITH is built on the principle that weather forecasting should be accessible to everyone, and that means building in the open with help from anyone who shares that vision.

### Ways to Contribute

- **Code**: Model improvements, new features, bug fixes
- **Data**: Additional data sources, quality control improvements
- **Documentation**: Tutorials, guides, API documentation
- **Testing**: Unit tests, integration tests, benchmarking
- **Design**: UI/UX improvements, visualizations

### Development Setup

```bash
# Fork and clone (replace with your username if you fork)
git clone https://github.com/consigcody94/lilith.git
cd lilith

# Install development dependencies
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install

# Run tests
pytest tests/ -v

# Run linting
ruff check .
mypy .
```

### Pull Request Process

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Make your changes
4. Run tests and linting
5. Commit with clear messages
6. Push and open a Pull Request

---

## Acknowledgments

### U.S. Government AI Initiatives

We thank **President Donald Trump** and his administration for the **Stargate AI Initiative** and its commitment to advancing American AI research and infrastructure. The recognition that AI development—including open-source projects like LILITH—represents a critical frontier for innovation, economic growth, and global competitiveness has helped create an environment where ambitious projects like this can flourish. The initiative's focus on building domestic AI capabilities and infrastructure supports the democratization of advanced technologies for all Americans.

### Data Providers

- **NOAA NCEI** — For maintaining the invaluable GHCN dataset as a public resource funded by U.S. taxpayers
- **ECMWF** — For ERA5 reanalysis data

### Research Community

- **GraphCast** (Google DeepMind) — Pioneering ML weather prediction
- **Pangu-Weather** (Huawei) — Advancing transformer architectures for weather
- **FourCastNet** (NVIDIA) — Demonstrating Fourier neural operators for atmospheric modeling
- **FuXi** (Fudan University) — Pushing boundaries in subseasonal forecasting

### Open Source

- The PyTorch team for the deep learning framework
- Hugging Face for model hosting infrastructure
- The countless contributors to the Python scientific computing ecosystem

---

## Configuration
|
| 1058 |
+
|
| 1059 |
+
### Environment Variables
|
| 1060 |
+
|
| 1061 |
+
Copy `.env.example` to `.env` and configure:
|
| 1062 |
+
|
| 1063 |
+
```bash
|
| 1064 |
+
cp .env.example .env
|
| 1065 |
+
```
|
| 1066 |
+
|
| 1067 |
+
| Variable | Required | Default | Description |
|
| 1068 |
+
|----------|----------|---------|-------------|
|
| 1069 |
+
| `OPENWEATHER_API_KEY` | Yes (for live data) | `YOUR_OPENWEATHER_API_KEY_HERE` | Free API key from [OpenWeatherMap](https://openweathermap.org/api) |
|
| 1070 |
+
| `LILITH_CHECKPOINT` | No | Auto-detected | Path to trained model checkpoint |
|
| 1071 |
+
|
### Getting an OpenWeatherMap API Key

1. Sign up at [OpenWeatherMap](https://openweathermap.org/users/sign_up) (free)
2. Go to [API Keys](https://home.openweathermap.org/api_keys)
3. Copy your API key
4. Set the environment variable:

```bash
# Linux/Mac
export OPENWEATHER_API_KEY="your_key_here"

# Windows PowerShell
$env:OPENWEATHER_API_KEY="your_key_here"

# Windows CMD
set OPENWEATHER_API_KEY=your_key_here
```

### Using the Pre-trained Model

A pre-trained model is available in the releases. This model was trained on:

- **505 US GHCN stations** with 9.6 million weather records
- **1.15 million training sequences**
- **10 epochs** of training (~5 hours on CPU, ~1 hour on GPU)
- **Final RMSE: 3.88°C** (temperature prediction accuracy)

Download and use:

```bash
# Download from releases
curl -L -o checkpoints/lilith_best.pt https://github.com/consigcody94/lilith/releases/download/v1.0/lilith_best.pt

# Start with the model
LILITH_CHECKPOINT=checkpoints/lilith_best.pt python -m uvicorn web.api.main:app --port 8000
```
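Before starting the server, you can sanity-check the downloaded checkpoint with PyTorch. This is a minimal sketch that assumes only that the file is a standard `torch.save` artifact; the exact keys stored inside depend on the training script:

```python
import torch

# Load on CPU so this works on machines without a GPU
ckpt = torch.load("checkpoints/lilith_best.pt", map_location="cpu")

# The layout is training-script specific; just list what is stored
if isinstance(ckpt, dict):
    print(sorted(ckpt.keys()))
else:
    print(type(ckpt))
```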
### Live Data & Caching

LILITH fetches live data from external APIs. To avoid hitting rate limits:

#### OpenWeatherMap (Forecast Adjustments)

- **Source**: api.openweathermap.org
- **Cache**: 15 minutes per location
- **Rate Limit**: 1,000 calls/day on free tier
- Used for fallback forecasts when the ML model is unavailable (a caching sketch follows this list)
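A per-location cache with a 15-minute TTL keeps a modest deployment well under the free-tier quota. Here is a minimal sketch of that idea; `fetch_forecast` is a hypothetical callable standing in for the actual OpenWeatherMap client, not the project's real API:

```python
import time
from typing import Any, Callable, Dict, Tuple

CACHE_TTL_SECONDS = 15 * 60  # mirrors the 15-minute cache described above
_cache: Dict[Tuple[float, float], Tuple[float, Any]] = {}


def cached_fetch(
    lat: float,
    lon: float,
    fetch_forecast: Callable[[float, float], Any],  # hypothetical OWM call
) -> Any:
    """Return a cached forecast for (lat, lon), refetching after the TTL."""
    key = (round(lat, 2), round(lon, 2))  # coarse key: nearby requests share an entry
    now = time.time()
    if key in _cache and now - _cache[key][0] < CACHE_TTL_SECONDS:
        return _cache[key][1]
    result = fetch_forecast(lat, lon)
    _cache[key] = (now, result)
    return result
```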
To disable live data fetching entirely and use only the ML model:

```python
# In web/api/main.py, set _weather_service to None
_weather_service = None  # Disables OpenWeatherMap calls
```

### Running Without API Keys

If you don't want to set up API keys, the app will still work but with limited features:

| Feature | With API Key | Without API Key |
|---------|--------------|-----------------|
| ML Forecasts | ✅ Full functionality | ✅ Full functionality |
| Fallback Forecasts | ✅ OWM-based | ❌ Error if model not loaded |

### Data Directory Structure

```
data/
├── raw/
│   └── ghcn_daily/              # Downloaded GHCN station files
│       ├── stations/            # .dly files (gitignored)
│       ├── ghcnd-stations.txt
│       └── ghcnd-inventory.txt
├── processed/
│   └── training/                # Processed training data (gitignored)
│       ├── X.npy                # Input sequences
│       ├── Y.npy                # Target sequences
│       └── stats.npz            # Normalization stats
└── training_stations.json       # Station coordinates (500+ stations)

checkpoints/
├── lilith_best.pt               # Best model checkpoint
└── lilith_*.pt                  # Other checkpoints (gitignored)
```

### Avoiding Data Re-downloads

Training data is cached locally. To avoid re-downloading on every build:

```bash
# Check if data exists before downloading
if [ ! -d "data/raw/ghcn_daily/stations" ]; then
    python scripts/download_data.py --max-stations 500
fi

# Or use the --skip-existing flag
python scripts/download_data.py --max-stations 500 --skip-existing
```

---

## License

```
Copyright 2025 LILITH Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

---

## Citation

If you use LILITH in your research, please cite:

```bibtex
@software{lilith2025,
  author = {LILITH Contributors},
  title = {LILITH: Long-range Intelligent Learning for Integrated Trend Hindcasting},
  year = {2025},
  url = {https://github.com/consigcody94/lilith}
}
```

---

<p align="center">
  <br>
  <em>"The storm goddess sees all horizons."</em>
  <br><br>
  <strong>Weather prediction should be free. The data is public. The science is open. Now the tools are too.</strong>
</p>
data/download/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""GHCN Data Download Scripts."""

from data.download.ghcn_daily import GHCNDailyDownloader
from data.download.ghcn_hourly import GHCNHourlyDownloader

__all__ = ["GHCNDailyDownloader", "GHCNHourlyDownloader"]
data/download/ghcn_daily.py
ADDED
@@ -0,0 +1,517 @@
"""
GHCN-Daily Data Downloader

Downloads and parses GHCN-Daily data from NOAA NCEI.
https://www.ncei.noaa.gov/products/land-based-station/global-historical-climatology-network-daily

Data format documentation:
https://www.ncei.noaa.gov/pub/data/ghcn/daily/readme.txt
"""

import gzip
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Generator, Optional, List, Union

import httpx
import pandas as pd
from loguru import logger
from tqdm import tqdm


@dataclass
class Station:
    """GHCN Station metadata."""

    id: str
    latitude: float
    longitude: float
    elevation: float
    state: Optional[str]
    name: str
    gsn_flag: Optional[str]
    hcn_flag: Optional[str]
    wmo_id: Optional[str]


@dataclass
class DailyObservation:
    """Single daily observation record."""

    station_id: str
    date: datetime
    element: str  # TMAX, TMIN, PRCP, SNOW, SNWD, etc.
    value: float
    m_flag: Optional[str]  # Measurement flag
    q_flag: Optional[str]  # Quality flag
    s_flag: Optional[str]  # Source flag


class GHCNDailyDownloader:
    """
    Downloads and parses GHCN-Daily data.

    GHCN-Daily contains daily climate summaries from land surface stations
    across the globe, with records from over 100,000 stations in 180 countries.

    Example usage:
        downloader = GHCNDailyDownloader(output_dir="data/raw/ghcn_daily")
        downloader.download_stations()
        downloader.download_inventory()

        # Download data for specific stations
        for station in downloader.get_stations(country="US", min_years=50):
            downloader.download_station_data(station.id)
    """

    BASE_URL = "https://www.ncei.noaa.gov/pub/data/ghcn/daily/"

    # Element codes we care about
    ELEMENTS = {
        "TMAX": "Maximum temperature (tenths of degrees C)",
        "TMIN": "Minimum temperature (tenths of degrees C)",
        "PRCP": "Precipitation (tenths of mm)",
        "SNOW": "Snowfall (mm)",
        "SNWD": "Snow depth (mm)",
        "AWND": "Average daily wind speed (tenths of m/s)",
        "TAVG": "Average temperature (tenths of degrees C)",
        "RHAV": "Average relative humidity (%)",
        "RHMX": "Maximum relative humidity (%)",
        "RHMN": "Minimum relative humidity (%)",
    }

    def __init__(
        self,
        output_dir: Union[str, Path] = "data/raw/ghcn_daily",
        timeout: float = 60.0,
    ):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.timeout = timeout
        self._client: Optional[httpx.Client] = None

    @property
    def client(self) -> httpx.Client:
        """Lazy-initialized HTTP client."""
        if self._client is None:
            self._client = httpx.Client(timeout=self.timeout, follow_redirects=True)
        return self._client

    def __enter__(self) -> "GHCNDailyDownloader":
        return self

    def __exit__(self, *args) -> None:
        if self._client:
            self._client.close()

    def download_stations(self, force: bool = False) -> Path:
        """
        Download station metadata file (ghcnd-stations.txt).

        Returns path to the downloaded file.
        """
        url = f"{self.BASE_URL}ghcnd-stations.txt"
        output_path = self.output_dir / "ghcnd-stations.txt"

        if output_path.exists() and not force:
            logger.info(f"Stations file already exists: {output_path}")
            return output_path

        logger.info(f"Downloading stations from {url}")
        response = self.client.get(url)
        response.raise_for_status()

        output_path.write_text(response.text)
        logger.success(f"Downloaded stations to {output_path}")
        return output_path

    def download_inventory(self, force: bool = False) -> Path:
        """
        Download station inventory file (ghcnd-inventory.txt).

        The inventory shows which elements are available for each station
        and the period of record.
        """
        url = f"{self.BASE_URL}ghcnd-inventory.txt"
        output_path = self.output_dir / "ghcnd-inventory.txt"

        if output_path.exists() and not force:
            logger.info(f"Inventory file already exists: {output_path}")
            return output_path

        logger.info(f"Downloading inventory from {url}")
        response = self.client.get(url)
        response.raise_for_status()

        output_path.write_text(response.text)
        logger.success(f"Downloaded inventory to {output_path}")
        return output_path

    def parse_stations(self, path: Optional[Path] = None) -> List[Station]:
        """
        Parse the stations metadata file.

        Format (fixed-width):
            ID            1-11   Character
            LATITUDE     13-20   Real
            LONGITUDE    22-30   Real
            ELEVATION    32-37   Real
            STATE        39-40   Character
            NAME         42-71   Character
            GSN FLAG     73-75   Character
            HCN/CRN FLAG 77-79   Character
            WMO ID       81-85   Character
        """
        if path is None:
            path = self.output_dir / "ghcnd-stations.txt"

        if not path.exists():
            self.download_stations()

        stations = []
        with open(path) as f:
            for line in f:
                if len(line.strip()) < 40:
                    continue

                station = Station(
                    id=line[0:11].strip(),
                    latitude=float(line[12:20].strip()),
                    longitude=float(line[21:30].strip()),
                    elevation=float(line[31:37].strip()) if line[31:37].strip() else 0.0,
                    state=line[38:40].strip() or None,
                    name=line[41:71].strip(),
                    gsn_flag=line[72:75].strip() or None,
                    hcn_flag=line[76:79].strip() or None,
                    wmo_id=line[80:85].strip() or None,
                )
                stations.append(station)

        logger.info(f"Parsed {len(stations)} stations")
        return stations

    def parse_inventory(self, path: Optional[Path] = None) -> pd.DataFrame:
        """
        Parse the inventory file.

        Format (fixed-width):
            ID         1-11   Character
            LATITUDE  13-20   Real
            LONGITUDE 22-30   Real
            ELEMENT   32-35   Character
            FIRSTYEAR 37-40   Integer
            LASTYEAR  42-45   Integer
        """
        if path is None:
            path = self.output_dir / "ghcnd-inventory.txt"

        if not path.exists():
            self.download_inventory()

        records = []
        with open(path) as f:
            for line in f:
                if len(line.strip()) < 45:
                    continue

                records.append(
                    {
                        "station_id": line[0:11].strip(),
                        "latitude": float(line[12:20].strip()),
                        "longitude": float(line[21:30].strip()),
                        "element": line[31:35].strip(),
                        "first_year": int(line[36:40].strip()),
                        "last_year": int(line[41:45].strip()),
                    }
                )

        df = pd.DataFrame(records)
        logger.info(f"Parsed {len(df)} inventory records")
        return df

    def get_stations(
        self,
        country: Optional[str] = None,
        min_years: int = 0,
        elements: Optional[List[str]] = None,
        bbox: Optional[tuple[float, float, float, float]] = None,
    ) -> List[Station]:
        """
        Get stations matching criteria.

        Args:
            country: 2-letter country code (first 2 chars of station ID)
            min_years: Minimum years of data required
            elements: Required elements (e.g., ["TMAX", "TMIN", "PRCP"])
            bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)

        Returns:
            List of matching stations
        """
        stations = self.parse_stations()
        inventory = self.parse_inventory()

        # Filter by country
        if country:
            stations = [s for s in stations if s.id.startswith(country)]

        # Filter by bounding box
        if bbox:
            min_lon, min_lat, max_lon, max_lat = bbox
            stations = [
                s
                for s in stations
                if min_lon <= s.longitude <= max_lon and min_lat <= s.latitude <= max_lat
            ]

        # Filter by data availability using VECTORIZED pandas operations (fast!)
        if min_years > 0 or elements:
            elements = elements or list(self.ELEMENTS.keys())

            # Create a station ID set for fast lookup
            station_ids = {s.id for s in stations}

            # Filter inventory to only include our stations and required elements
            inv_filtered = inventory[
                (inventory["station_id"].isin(station_ids)) &
                (inventory["element"].isin(elements))
            ].copy()

            # Calculate years of data for each station-element combo
            inv_filtered["years"] = inv_filtered["last_year"] - inv_filtered["first_year"]

            # Group by station and check requirements
            station_stats = inv_filtered.groupby("station_id").agg({
                "element": "nunique",  # Count unique elements
                "years": "max"  # Max years of any element
            }).reset_index()

            # Filter stations that have all required elements and enough years
            valid_stations = station_stats[
                (station_stats["element"] >= len(elements)) &
                (station_stats["years"] >= min_years)
            ]["station_id"].tolist()

            valid_ids = set(valid_stations)
            stations = [s for s in stations if s.id in valid_ids]

        logger.info(f"Found {len(stations)} matching stations")
        return stations

    def download_station_data(
        self,
        station_id: str,
        force: bool = False,
    ) -> Path:
        """
        Download data file for a single station.

        The data is stored in .dly format (one file per station).
        """
        # Station data is in the 'all' subdirectory as .dly.gz files
        url = f"{self.BASE_URL}all/{station_id}.dly"
        output_path = self.output_dir / "stations" / f"{station_id}.dly"
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if output_path.exists() and not force:
            logger.debug(f"Station data already exists: {output_path}")
            return output_path

        logger.debug(f"Downloading {station_id}")

        try:
            response = self.client.get(url)
            response.raise_for_status()
            output_path.write_text(response.text)
        except httpx.HTTPStatusError:
            # Try gzipped version
            url_gz = f"{url}.gz"
            response = self.client.get(url_gz)
            response.raise_for_status()

            # Decompress
            content = gzip.decompress(response.content)
            output_path.write_bytes(content)

        return output_path

    def parse_station_data(self, station_id: str) -> Generator[DailyObservation, None, None]:
        """
        Parse a station's .dly file and yield observations.

        Format (fixed-width, one line per station-year-month-element):
            ID       1-11   Character
            YEAR    12-15   Integer
            MONTH   16-17   Integer
            ELEMENT 18-21   Character
            VALUE1  22-26   Integer (day 1)
            MFLAG1  27-27   Character
            QFLAG1  28-28   Character
            SFLAG1  29-29   Character
            ... repeated for days 2-31
        """
        path = self.output_dir / "stations" / f"{station_id}.dly"
        if not path.exists():
            self.download_station_data(station_id)

        with open(path) as f:
            for line in f:
                if len(line) < 269:
                    continue

                station = line[0:11].strip()
                year = int(line[11:15])
                month = int(line[15:17])
                element = line[17:21].strip()

                # Skip elements we don't care about
                if element not in self.ELEMENTS:
                    continue

                # Parse each day's value (31 days max)
                for day in range(1, 32):
                    offset = 21 + (day - 1) * 8
                    value_str = line[offset : offset + 5].strip()
                    m_flag = line[offset + 5 : offset + 6].strip() or None
                    q_flag = line[offset + 6 : offset + 7].strip() or None
                    s_flag = line[offset + 7 : offset + 8].strip() or None

                    # -9999 indicates missing value
                    if value_str == "-9999" or not value_str:
                        continue

                    try:
                        date = datetime(year, month, day)
                    except ValueError:
                        # Invalid date (e.g., Feb 30)
                        continue

                    # Convert value (stored as tenths for most elements)
                    value = float(value_str)
                    if element in ("TMAX", "TMIN", "TAVG", "PRCP", "AWND"):
                        value /= 10.0

                    yield DailyObservation(
                        station_id=station,
                        date=date,
                        element=element,
                        value=value,
                        m_flag=m_flag,
                        q_flag=q_flag,
                        s_flag=s_flag,
                    )

    def station_to_dataframe(self, station_id: str) -> pd.DataFrame:
        """
        Load station data as a pandas DataFrame.

        Returns a DataFrame with columns for each element and a datetime index.
        """
        observations = list(self.parse_station_data(station_id))

        if not observations:
            return pd.DataFrame()

        # Convert to DataFrame
        df = pd.DataFrame([vars(o) for o in observations])

        # Pivot to have elements as columns
        df = df.pivot_table(
            index="date",
            columns="element",
            values="value",
            aggfunc="first",
        )

        df.index = pd.to_datetime(df.index)
        df = df.sort_index()

        return df

    def download_all(
        self,
        stations: Optional[List[Station]] = None,
        max_stations: Optional[int] = None,
        **filter_kwargs,
    ) -> List[Path]:
        """
        Download data for multiple stations.

        Args:
            stations: List of stations to download (or use filter_kwargs)
            max_stations: Maximum number of stations to download
            **filter_kwargs: Arguments passed to get_stations()

        Returns:
            List of paths to downloaded files
        """
        if stations is None:
            stations = self.get_stations(**filter_kwargs)

        if max_stations:
            stations = stations[:max_stations]

        paths = []
        for station in tqdm(stations, desc="Downloading stations"):
            try:
                path = self.download_station_data(station.id)
                paths.append(path)
            except Exception as e:
                logger.warning(f"Failed to download {station.id}: {e}")

        logger.success(f"Downloaded {len(paths)} station files")
        return paths


def main():
    """CLI entry point for downloading GHCN-Daily data."""
    import argparse

    parser = argparse.ArgumentParser(description="Download GHCN-Daily data")
    parser.add_argument(
        "--output-dir",
        default="data/raw/ghcn_daily",
        help="Output directory for downloaded data",
    )
    parser.add_argument(
        "--country",
        default=None,
        help="Filter by country code (e.g., US, CA, GB)",
    )
    parser.add_argument(
        "--min-years",
        type=int,
        default=30,
        help="Minimum years of data required",
    )
    parser.add_argument(
        "--max-stations",
        type=int,
        default=None,
        help="Maximum number of stations to download",
    )
    parser.add_argument(
        "--stations-only",
        action="store_true",
        help="Only download station metadata, not observation data",
    )

    args = parser.parse_args()

    with GHCNDailyDownloader(output_dir=args.output_dir) as downloader:
        # Always download metadata
        downloader.download_stations()
        downloader.download_inventory()

        if not args.stations_only:
            downloader.download_all(
                country=args.country,
                min_years=args.min_years,
                max_stations=args.max_stations,
            )


if __name__ == "__main__":
    main()
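As a quick smoke test of `GHCNDailyDownloader`, the sketch below strings the pieces together; the filter values and the choice of the first matching station are illustrative:

```python
from data.download.ghcn_daily import GHCNDailyDownloader

with GHCNDailyDownloader(output_dir="data/raw/ghcn_daily") as dl:
    # Long-lived US stations reporting temperature and precipitation
    stations = dl.get_stations(country="US", min_years=30, elements=["TMAX", "TMIN", "PRCP"])
    if stations:
        # station_to_dataframe downloads the .dly file on demand if missing
        df = dl.station_to_dataframe(stations[0].id)
        print(df.head())
```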
data/download/ghcn_hourly.py
ADDED
@@ -0,0 +1,465 @@
"""
GHCN-Hourly Data Downloader

Downloads and parses GHCN-Hourly (formerly ISD) data from NOAA NCEI.
https://www.ncei.noaa.gov/products/global-historical-climatology-network-hourly

This dataset includes wind, temperature, pressure, humidity, clouds, and more
at hourly resolution from 20,000+ stations worldwide.
"""

import gzip
import json
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Generator, Optional, List, Union

import httpx
import pandas as pd
from loguru import logger
from tqdm import tqdm


@dataclass
class HourlyStation:
    """GHCN-Hourly station metadata."""

    usaf: str  # USAF station ID
    wban: str  # WBAN station ID
    station_name: str
    country: str
    state: Optional[str]
    latitude: float
    longitude: float
    elevation: float
    begin_date: datetime
    end_date: datetime

    @property
    def id(self) -> str:
        """Combined station ID."""
        return f"{self.usaf}-{self.wban}"


@dataclass
class HourlyObservation:
    """Single hourly observation record."""

    station_id: str
    timestamp: datetime
    latitude: float
    longitude: float
    elevation: float

    # Wind
    wind_direction: Optional[float]  # degrees
    wind_speed: Optional[float]  # m/s
    wind_gust: Optional[float]  # m/s

    # Temperature
    temperature: Optional[float]  # °C
    dew_point: Optional[float]  # °C

    # Pressure
    sea_level_pressure: Optional[float]  # hPa
    station_pressure: Optional[float]  # hPa

    # Humidity
    relative_humidity: Optional[float]  # %

    # Visibility
    visibility: Optional[float]  # meters

    # Precipitation
    precipitation_1h: Optional[float]  # mm
    precipitation_6h: Optional[float]  # mm

    # Sky condition
    cloud_ceiling: Optional[float]  # meters
    cloud_coverage: Optional[str]  # e.g., "CLR", "FEW", "SCT", "BKN", "OVC"

    # Quality
    quality_control: str


class GHCNHourlyDownloader:
    """
    Downloads and parses GHCN-Hourly (ISD-Lite) data.

    GHCN-Hourly provides sub-daily observations including wind, temperature,
    pressure, and humidity from global surface stations.

    We use the ISD-Lite format which is a simplified version containing the
    most essential variables.

    Example usage:
        downloader = GHCNHourlyDownloader(output_dir="data/raw/ghcn_hourly")
        downloader.download_station_list()

        # Download data for specific stations and years
        for station in downloader.get_stations(country="US", min_years=30):
            downloader.download_station_year(station.usaf, station.wban, 2023)
    """

    # ISD-Lite base URL (simplified hourly format)
    BASE_URL = "https://www.ncei.noaa.gov/pub/data/noaa/isd-lite/"
    STATION_LIST_URL = "https://www.ncei.noaa.gov/pub/data/noaa/isd-history.csv"

    def __init__(
        self,
        output_dir: Union[str, Path] = "data/raw/ghcn_hourly",
        timeout: float = 60.0,
    ):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.timeout = timeout
        self._client: Optional[httpx.Client] = None

    @property
    def client(self) -> httpx.Client:
        """Lazy-initialized HTTP client."""
        if self._client is None:
            self._client = httpx.Client(timeout=self.timeout, follow_redirects=True)
        return self._client

    def __enter__(self) -> "GHCNHourlyDownloader":
        return self

    def __exit__(self, *args) -> None:
        if self._client:
            self._client.close()

    def download_station_list(self, force: bool = False) -> Path:
        """Download the station history/metadata file."""
        output_path = self.output_dir / "isd-history.csv"

        if output_path.exists() and not force:
            logger.info(f"Station list already exists: {output_path}")
            return output_path

        logger.info(f"Downloading station list from {self.STATION_LIST_URL}")
        response = self.client.get(self.STATION_LIST_URL)
        response.raise_for_status()

        output_path.write_text(response.text)
        logger.success(f"Downloaded station list to {output_path}")
        return output_path

    def parse_stations(self, path: Optional[Path] = None) -> List[HourlyStation]:
        """Parse the station history CSV file."""
        if path is None:
            path = self.output_dir / "isd-history.csv"

        if not path.exists():
            self.download_station_list()

        df = pd.read_csv(path, low_memory=False)

        stations = []
        for _, row in df.iterrows():
            try:
                # Skip stations with missing coordinates
                if pd.isna(row.get("LAT")) or pd.isna(row.get("LON")):
                    continue

                station = HourlyStation(
                    usaf=str(row["USAF"]).zfill(6),
                    wban=str(row["WBAN"]).zfill(5),
                    station_name=str(row.get("STATION NAME", "")),
                    country=str(row.get("CTRY", "")),
                    state=str(row.get("STATE", "")) if pd.notna(row.get("STATE")) else None,
                    latitude=float(row["LAT"]),
                    longitude=float(row["LON"]),
                    elevation=float(row.get("ELEV(M)", 0)) if pd.notna(row.get("ELEV(M)")) else 0.0,
                    begin_date=pd.to_datetime(str(row.get("BEGIN", "19000101")), format="%Y%m%d"),
                    end_date=pd.to_datetime(str(row.get("END", "20991231")), format="%Y%m%d"),
                )
                stations.append(station)
            except Exception as e:
                logger.debug(f"Skipping station: {e}")
                continue

        logger.info(f"Parsed {len(stations)} stations")
        return stations

    def get_stations(
        self,
        country: Optional[str] = None,
        min_years: int = 0,
        bbox: Optional[tuple[float, float, float, float]] = None,
        active_only: bool = True,
    ) -> List[HourlyStation]:
        """
        Get stations matching criteria.

        Args:
            country: 2-letter country code
            min_years: Minimum years of data required
            bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
            active_only: Only include stations with data through 2023+

        Returns:
            List of matching stations
        """
        stations = self.parse_stations()

        if country:
            stations = [s for s in stations if s.country == country]

        if bbox:
            min_lon, min_lat, max_lon, max_lat = bbox
            stations = [
                s
                for s in stations
                if min_lon <= s.longitude <= max_lon and min_lat <= s.latitude <= max_lat
            ]

        if min_years > 0:
            stations = [
                s
                for s in stations
                if (s.end_date - s.begin_date).days / 365 >= min_years
            ]

        if active_only:
            cutoff = datetime(2023, 1, 1)
            stations = [s for s in stations if s.end_date >= cutoff]

        logger.info(f"Found {len(stations)} matching stations")
        return stations

    def download_station_year(
        self,
        usaf: str,
        wban: str,
        year: int,
        force: bool = False,
    ) -> Optional[Path]:
        """
        Download ISD-Lite data for a station-year.

        ISD-Lite files are organized by year: {year}/{usaf}-{wban}-{year}.gz
        """
        filename = f"{usaf}-{wban}-{year}.gz"
        url = f"{self.BASE_URL}{year}/{filename}"
        output_path = self.output_dir / "data" / str(year) / filename

        if output_path.exists() and not force:
            logger.debug(f"Data already exists: {output_path}")
            return output_path

        output_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            response = self.client.get(url)
            response.raise_for_status()
            output_path.write_bytes(response.content)
            return output_path
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                logger.debug(f"No data for {usaf}-{wban} in {year}")
                return None
            raise

    def parse_isd_lite(
        self,
        usaf: str,
        wban: str,
        year: int,
    ) -> Generator[HourlyObservation, None, None]:
        """
        Parse an ISD-Lite file and yield observations.

        ISD-Lite format (fixed-width, space-separated):
            Field 1:  Year
            Field 2:  Month
            Field 3:  Day
            Field 4:  Hour
            Field 5:  Air Temperature (°C * 10)
            Field 6:  Dew Point Temperature (°C * 10)
            Field 7:  Sea Level Pressure (hPa * 10)
            Field 8:  Wind Direction (degrees)
            Field 9:  Wind Speed (m/s * 10)
            Field 10: Sky Condition Total Coverage Code
            Field 11: Liquid Precipitation Depth 1-Hour (mm * 10)
            Field 12: Liquid Precipitation Depth 6-Hour (mm * 10)

        Missing values are represented as -9999.
        """
        path = self.output_dir / "data" / str(year) / f"{usaf}-{wban}-{year}.gz"

        if not path.exists():
            result = self.download_station_year(usaf, wban, year)
            if result is None:
                return

        station_id = f"{usaf}-{wban}"

        with gzip.open(path, "rt") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 12:
                    continue

                try:
                    year_val = int(parts[0])
                    month = int(parts[1])
                    day = int(parts[2])
                    hour = int(parts[3])

                    timestamp = datetime(year_val, month, day, hour)

                    # Parse values (-9999 = missing)
                    def parse_val(idx: int, scale: float = 10.0) -> Optional[float]:
                        val = int(parts[idx])
                        return val / scale if val != -9999 else None

                    yield HourlyObservation(
                        station_id=station_id,
                        timestamp=timestamp,
                        latitude=0.0,  # Need to lookup from station metadata
                        longitude=0.0,
                        elevation=0.0,
                        wind_direction=parse_val(7, 1.0),
                        wind_speed=parse_val(8, 10.0),
                        wind_gust=None,
                        temperature=parse_val(4, 10.0),
                        dew_point=parse_val(5, 10.0),
                        sea_level_pressure=parse_val(6, 10.0),
                        station_pressure=None,
                        relative_humidity=None,  # Computed from temp/dew point
                        visibility=None,
                        precipitation_1h=parse_val(10, 10.0),
                        precipitation_6h=parse_val(11, 10.0),
                        cloud_ceiling=None,
                        cloud_coverage=str(int(parts[9])) if int(parts[9]) != -9999 else None,
                        quality_control="",
                    )
                except (ValueError, IndexError) as e:
                    logger.debug(f"Parse error: {e}")
                    continue

    def station_year_to_dataframe(
        self,
        usaf: str,
        wban: str,
        year: int,
    ) -> pd.DataFrame:
        """Load station-year data as a pandas DataFrame."""
        observations = list(self.parse_isd_lite(usaf, wban, year))

        if not observations:
            return pd.DataFrame()

        df = pd.DataFrame([vars(o) for o in observations])
        df = df.set_index("timestamp").sort_index()

        return df

    def download_station_range(
        self,
        usaf: str,
        wban: str,
        start_year: int,
        end_year: int,
    ) -> List[Path]:
        """Download multiple years of data for a station."""
        paths = []
        for year in range(start_year, end_year + 1):
            result = self.download_station_year(usaf, wban, year)
            if result:
                paths.append(result)
        return paths

    def download_all(
        self,
        stations: Optional[List[HourlyStation]] = None,
        years: Optional[List[int]] = None,
        max_stations: Optional[int] = None,
        **filter_kwargs,
    ) -> int:
        """
        Download data for multiple stations and years.

        Returns count of files downloaded.
        """
        if stations is None:
            stations = self.get_stations(**filter_kwargs)

        if max_stations:
            stations = stations[:max_stations]

        if years is None:
            years = list(range(2000, 2024))

        count = 0
        for station in tqdm(stations, desc="Downloading stations"):
            for year in years:
                try:
                    result = self.download_station_year(station.usaf, station.wban, year)
                    if result:
                        count += 1
                except Exception as e:
                    logger.warning(f"Failed to download {station.id}/{year}: {e}")

        logger.success(f"Downloaded {count} station-year files")
        return count


def main():
    """CLI entry point for downloading GHCN-Hourly data."""
    import argparse

    parser = argparse.ArgumentParser(description="Download GHCN-Hourly (ISD-Lite) data")
    parser.add_argument(
        "--output-dir",
        default="data/raw/ghcn_hourly",
        help="Output directory for downloaded data",
    )
    parser.add_argument(
        "--country",
        default=None,
        help="Filter by country code (e.g., US, CA, GB)",
    )
    parser.add_argument(
        "--min-years",
        type=int,
        default=20,
        help="Minimum years of data required",
    )
    parser.add_argument(
        "--max-stations",
        type=int,
        default=None,
        help="Maximum number of stations to download",
    )
    parser.add_argument(
        "--start-year",
        type=int,
        default=2000,
        help="Start year for data download",
    )
    parser.add_argument(
        "--end-year",
        type=int,
        default=2023,
        help="End year for data download",
    )

    args = parser.parse_args()

    with GHCNHourlyDownloader(output_dir=args.output_dir) as downloader:
        downloader.download_station_list()

        years = list(range(args.start_year, args.end_year + 1))
        downloader.download_all(
            country=args.country,
            min_years=args.min_years,
            max_stations=args.max_stations,
            years=years,
        )


if __name__ == "__main__":
    main()
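Reading one station-year with `GHCNHourlyDownloader` looks like the sketch below; the choice of the first matching station and the year 2023 are illustrative:

```python
from data.download.ghcn_hourly import GHCNHourlyDownloader

with GHCNHourlyDownloader(output_dir="data/raw/ghcn_hourly") as dl:
    stations = dl.get_stations(country="US", min_years=20)
    if stations:
        st = stations[0]
        # Fetches and parses {usaf}-{wban}-2023.gz on demand
        df = dl.station_year_to_dataframe(st.usaf, st.wban, 2023)
        if not df.empty:
            print(df[["temperature", "dew_point", "wind_speed"]].head())
```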
data/loaders/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""PyTorch DataLoaders for LILITH."""

from data.loaders.station_dataset import StationDataset, StationDataModule
from data.loaders.forecast_dataset import ForecastDataset

__all__ = ["StationDataset", "StationDataModule", "ForecastDataset"]
data/loaders/forecast_dataset.py
ADDED
@@ -0,0 +1,367 @@
"""
Forecast Dataset for LILITH.

Provides data loading optimized for multi-station forecasting
with graph-based models.
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from loguru import logger


class ForecastDataset(Dataset):
    """
    Dataset for graph-based multi-station forecasting.

    Instead of loading single stations, this dataset loads data for
    multiple stations simultaneously, suitable for GNN-based models.

    Each sample contains:
    - Observations from N stations for the input period
    - Targets for N stations for the forecast period
    - Station coordinates and connectivity graph
    """

    def __init__(
        self,
        data_dir: Union[str, Path],
        sequence_length: int = 30,
        forecast_length: int = 14,
        max_stations: int = 500,
        spatial_radius: float = 5.0,  # degrees
        target_variables: Optional[List[str]] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        seed: int = 42,
    ):
        """
        Initialize the forecast dataset.

        Args:
            data_dir: Directory with processed Parquet files
            sequence_length: Days of input history
            forecast_length: Days to forecast
            max_stations: Maximum stations per sample
            spatial_radius: Radius in degrees for station sampling
            target_variables: Variables to forecast
            start_date: Start date for data (YYYY-MM-DD)
            end_date: End date for data (YYYY-MM-DD)
            seed: Random seed for reproducibility
        """
        self.data_dir = Path(data_dir)
        self.sequence_length = sequence_length
        self.forecast_length = forecast_length
        self.total_length = sequence_length + forecast_length
        self.max_stations = max_stations
        self.spatial_radius = spatial_radius
        self.target_variables = target_variables or ["TMAX", "TMIN", "PRCP"]
        self.seed = seed

        self.rng = np.random.default_rng(seed)

        # Load station metadata
        self.stations = pd.read_parquet(self.data_dir / "stations.parquet")

        # Parse date range
        self.start_date = pd.Timestamp(start_date) if start_date else pd.Timestamp("2000-01-01")
        self.end_date = pd.Timestamp(end_date) if end_date else pd.Timestamp("2023-12-31")

        # Build date index
        self.dates = pd.date_range(
            self.start_date,
            self.end_date - pd.Timedelta(days=self.total_length),
            freq="D",
        )

        # Build spatial clusters for efficient sampling
        self._build_spatial_clusters()

        # Cache for loaded data
        self._data_cache: Dict[int, pd.DataFrame] = {}

        logger.info(
            f"ForecastDataset: {len(self.dates)} dates, "
            f"{len(self.stations)} stations, {len(self.clusters)} clusters"
        )

    def _build_spatial_clusters(self) -> None:
        """
        Build spatial clusters of stations for efficient sampling.

        Groups stations into overlapping clusters based on spatial proximity.
        """
        self.clusters = []

        # Grid-based clustering
        lat_bins = np.arange(-90, 90, self.spatial_radius * 2)
        lon_bins = np.arange(-180, 180, self.spatial_radius * 2)

        for lat in lat_bins:
            for lon in lon_bins:
                # Find stations in this grid cell (with overlap)
                mask = (
                    (self.stations["latitude"] >= lat - self.spatial_radius) &
                    (self.stations["latitude"] < lat + self.spatial_radius * 3) &
                    (self.stations["longitude"] >= lon - self.spatial_radius) &
                    (self.stations["longitude"] < lon + self.spatial_radius * 3)
                )
                cluster_stations = self.stations[mask]["station_id"].tolist()

                if len(cluster_stations) >= 10:  # Minimum cluster size
                    self.clusters.append({
                        "center_lat": lat + self.spatial_radius,
                        "center_lon": lon + self.spatial_radius,
                        "station_ids": cluster_stations,
                    })

    def _load_data_for_date(self, date: pd.Timestamp) -> pd.DataFrame:
        """Load data for a specific date range, with caching."""
        year = date.year
        end_year = (date + pd.Timedelta(days=self.total_length)).year

        # Load required years
        dfs = []
        for y in range(year, end_year + 1):
            if y in self._data_cache:
                dfs.append(self._data_cache[y])
            else:
                year_file = self.data_dir / f"observations_{y}.parquet"
                if year_file.exists():
                    df = pd.read_parquet(year_file)
                    self._data_cache[y] = df
                    dfs.append(df)

        if not dfs:
            return pd.DataFrame()

        return pd.concat(dfs)

    def _build_station_graph(
        self,
        station_coords: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build adjacency information for stations.

        Returns edge_index and edge_attr for PyTorch Geometric.

        Args:
            station_coords: (N, 3) array of [lat, lon, elev]

        Returns:
            edge_index: (2, E) source and target node indices
            edge_attr: (E, 1) edge distances
        """
        n_stations = len(station_coords)
        edges_src = []
        edges_dst = []
        edge_weights = []

        # Connect stations within spatial radius
        for i in range(n_stations):
            for j in range(i + 1, n_stations):
                # Calculate distance
                dlat = station_coords[i, 0] - station_coords[j, 0]
                dlon = station_coords[i, 1] - station_coords[j, 1]
                dist = np.sqrt(dlat**2 + dlon**2)

                if dist < self.spatial_radius:
                    # Bidirectional edges
                    edges_src.extend([i, j])
                    edges_dst.extend([j, i])
                    edge_weights.extend([dist, dist])

        if not edges_src:
            # Fallback: connect to k nearest neighbors
            from scipy.spatial import KDTree

            tree = KDTree(station_coords[:, :2])
            for i in range(n_stations):
                _, neighbors = tree.query(station_coords[i, :2], k=min(5, n_stations))
                for j in neighbors:
                    if i != j:
                        dist = np.linalg.norm(station_coords[i, :2] - station_coords[j, :2])
                        edges_src.append(i)
                        edges_dst.append(j)
                        edge_weights.append(dist)

        edge_index = np.array([edges_src, edges_dst], dtype=np.int64)
        edge_attr = np.array(edge_weights, dtype=np.float32).reshape(-1, 1)

        return edge_index, edge_attr

    def __len__(self) -> int:
        return len(self.dates) * len(self.clusters)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a multi-station sample.

        Returns:
            Dict with keys:
            - node_features: (N, seq_len, F) station observations
            - node_coords: (N, 3) lat/lon/elev
            - edge_index: (2, E) graph connectivity
            - edge_attr: (E, 1) edge weights
            - target_features: (N, forecast_len, T) targets
            - mask: (N, seq_len + forecast_len) valid mask
        """
        # Decode index
        date_idx = idx // len(self.clusters)
        cluster_idx = idx % len(self.clusters)

        date = self.dates[date_idx]
        cluster = self.clusters[cluster_idx]

        # Sample stations from cluster
        station_ids = cluster["station_ids"]
        if len(station_ids) > self.max_stations:
            station_ids = self.rng.choice(station_ids, self.max_stations, replace=False).tolist()

        n_stations = len(station_ids)

        # Load data
        data = self._load_data_for_date(date)
        if data.empty:
            return self._empty_sample(n_stations)

        # Filter to selected stations and date range
        end_date = date + pd.Timedelta(days=self.total_length - 1)
        mask = (
            data["station_id"].isin(station_ids) &
            (data.index >= date) &
            (data.index <= end_date)
        )
        data = data[mask]

        # Prepare feature arrays
        feature_cols = [c for c in self.target_variables if c in data.columns]
        n_features = len(feature_cols)

        node_features = np.zeros((n_stations, self.sequence_length, n_features), dtype=np.float32)
        target_features = np.zeros((n_stations, self.forecast_length, n_features), dtype=np.float32)
        node_coords = np.zeros((n_stations, 3), dtype=np.float32)
        valid_mask = np.zeros((n_stations, self.total_length), dtype=bool)

        # Fill in data for each station
        for i, station_id in enumerate(station_ids):
            station_data = data[data["station_id"] == station_id].sort_index()

            # Get station coords
            station_meta = self.stations[self.stations["station_id"] == station_id]
            if not station_meta.empty:
                node_coords[i] = [
                    station_meta.iloc[0]["latitude"],
                    station_meta.iloc[0]["longitude"],
                    station_meta.iloc[0].get("elevation", 0),
                ]

            # Fill input sequence
            for j, d in enumerate(pd.date_range(date, periods=self.sequence_length, freq="D")):
                if d in station_data.index:
                    row = station_data.loc[d]
                    if isinstance(row, pd.DataFrame):
                        row = row.iloc[0]
                    for k, col in enumerate(feature_cols):
                        val = row.get(col, np.nan)
                        if not pd.isna(val):
                            node_features[i, j, k] = val
                            valid_mask[i, j] = True

            # Fill target sequence
            target_start = date + pd.Timedelta(days=self.sequence_length)
            for j, d in enumerate(pd.date_range(target_start, periods=self.forecast_length, freq="D")):
| 280 |
+
if d in station_data.index:
|
| 281 |
+
row = station_data.loc[d]
|
| 282 |
+
if isinstance(row, pd.DataFrame):
|
| 283 |
+
row = row.iloc[0]
|
| 284 |
+
for k, col in enumerate(feature_cols):
|
| 285 |
+
val = row.get(col, np.nan)
|
| 286 |
+
if not pd.isna(val):
|
| 287 |
+
target_features[i, j, k] = val
|
| 288 |
+
valid_mask[i, self.sequence_length + j] = True
|
| 289 |
+
|
| 290 |
+
# Build graph
|
| 291 |
+
edge_index, edge_attr = self._build_station_graph(node_coords)
|
| 292 |
+
|
| 293 |
+
# Replace NaN with 0 (mask indicates valid values)
|
| 294 |
+
node_features = np.nan_to_num(node_features, nan=0.0)
|
| 295 |
+
target_features = np.nan_to_num(target_features, nan=0.0)
|
| 296 |
+
|
| 297 |
+
return {
|
| 298 |
+
"node_features": torch.from_numpy(node_features),
|
| 299 |
+
"node_coords": torch.from_numpy(node_coords),
|
| 300 |
+
"edge_index": torch.from_numpy(edge_index),
|
| 301 |
+
"edge_attr": torch.from_numpy(edge_attr),
|
| 302 |
+
"target_features": torch.from_numpy(target_features),
|
| 303 |
+
"mask": torch.from_numpy(valid_mask),
|
| 304 |
+
"n_stations": n_stations,
|
| 305 |
+
"date": str(date.date()),
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
def _empty_sample(self, n_stations: int) -> Dict[str, torch.Tensor]:
|
| 309 |
+
"""Return an empty sample for error cases."""
|
| 310 |
+
return {
|
| 311 |
+
"node_features": torch.zeros(n_stations, self.sequence_length, len(self.target_variables)),
|
| 312 |
+
"node_coords": torch.zeros(n_stations, 3),
|
| 313 |
+
"edge_index": torch.zeros(2, 0, dtype=torch.long),
|
| 314 |
+
"edge_attr": torch.zeros(0, 1),
|
| 315 |
+
"target_features": torch.zeros(n_stations, self.forecast_length, len(self.target_variables)),
|
| 316 |
+
"mask": torch.zeros(n_stations, self.total_length, dtype=torch.bool),
|
| 317 |
+
"n_stations": n_stations,
|
| 318 |
+
"date": "",
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def collate_variable_graphs(batch: List[Dict]) -> Dict[str, torch.Tensor]:
|
| 323 |
+
"""
|
| 324 |
+
Custom collate function for variable-size graphs.
|
| 325 |
+
|
| 326 |
+
Combines multiple samples into a single batched graph.
|
| 327 |
+
"""
|
| 328 |
+
# Stack fixed-size tensors
|
| 329 |
+
node_features = torch.cat([b["node_features"] for b in batch], dim=0)
|
| 330 |
+
node_coords = torch.cat([b["node_coords"] for b in batch], dim=0)
|
| 331 |
+
target_features = torch.cat([b["target_features"] for b in batch], dim=0)
|
| 332 |
+
masks = torch.cat([b["mask"] for b in batch], dim=0)
|
| 333 |
+
|
| 334 |
+
# Combine edge indices with offsets
|
| 335 |
+
edge_indices = []
|
| 336 |
+
edge_attrs = []
|
| 337 |
+
offset = 0
|
| 338 |
+
|
| 339 |
+
for b in batch:
|
| 340 |
+
edge_index = b["edge_index"]
|
| 341 |
+
if edge_index.size(1) > 0:
|
| 342 |
+
edge_indices.append(edge_index + offset)
|
| 343 |
+
edge_attrs.append(b["edge_attr"])
|
| 344 |
+
offset += b["n_stations"]
|
| 345 |
+
|
| 346 |
+
if edge_indices:
|
| 347 |
+
edge_index = torch.cat(edge_indices, dim=1)
|
| 348 |
+
edge_attr = torch.cat(edge_attrs, dim=0)
|
| 349 |
+
else:
|
| 350 |
+
edge_index = torch.zeros(2, 0, dtype=torch.long)
|
| 351 |
+
edge_attr = torch.zeros(0, 1)
|
| 352 |
+
|
| 353 |
+
# Batch indices for graph batching
|
| 354 |
+
batch_idx = torch.cat([
|
| 355 |
+
torch.full((b["n_stations"],), i, dtype=torch.long)
|
| 356 |
+
for i, b in enumerate(batch)
|
| 357 |
+
])
|
| 358 |
+
|
| 359 |
+
return {
|
| 360 |
+
"node_features": node_features,
|
| 361 |
+
"node_coords": node_coords,
|
| 362 |
+
"edge_index": edge_index,
|
| 363 |
+
"edge_attr": edge_attr,
|
| 364 |
+
"target_features": target_features,
|
| 365 |
+
"mask": masks,
|
| 366 |
+
"batch": batch_idx,
|
| 367 |
+
}
|
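A minimal, self-contained sketch of the collate function above in action; the sample dicts are hand-built stand-ins for what `__getitem__` returns, and the shapes are arbitrary choices for illustration:

import torch

from data.loaders.forecast_dataset import collate_variable_graphs

def fake_sample(n_stations: int) -> dict:
    # Mimics the dict layout produced by __getitem__ above.
    return {
        "node_features": torch.zeros(n_stations, 30, 3),
        "node_coords": torch.zeros(n_stations, 3),
        "edge_index": torch.tensor([[0], [n_stations - 1]], dtype=torch.long),
        "edge_attr": torch.ones(1, 1),
        "target_features": torch.zeros(n_stations, 14, 3),
        "mask": torch.zeros(n_stations, 44, dtype=torch.bool),
        "n_stations": n_stations,
        "date": "2020-01-01",
    }

batch = collate_variable_graphs([fake_sample(5), fake_sample(8)])
assert batch["node_features"].shape[0] == 13          # nodes from both samples concatenated
assert batch["edge_index"][0, 1].item() == 5          # second sample's edges shifted by the first's n_stations
assert batch["batch"].tolist() == [0] * 5 + [1] * 8   # node-to-sample assignment

The same function is what you would pass to a DataLoader via `collate_fn=collate_variable_graphs`, since the default collate cannot stack graphs with differing station counts.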
data/loaders/station_dataset.py
ADDED
@@ -0,0 +1,380 @@
"""
Station-based PyTorch Dataset for LILITH.

Provides efficient data loading for station observations with support for:
- Sequence-based loading for temporal models
- Multi-station batching for graph-based models
- Lazy loading for large datasets
- Train/val/test splitting
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Dict, List, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from loguru import logger


@dataclass
class StationSample:
    """A single training sample from a station."""

    station_id: str
    latitude: float
    longitude: float
    elevation: float

    # Input sequence
    input_features: torch.Tensor  # Shape: (seq_len, n_features)
    input_mask: torch.Tensor  # Shape: (seq_len,) - True for valid values

    # Target sequence (for forecasting)
    target_features: torch.Tensor  # Shape: (forecast_len, n_targets)
    target_mask: torch.Tensor  # Shape: (forecast_len,)

    # Timestamps
    input_timestamps: np.ndarray
    target_timestamps: np.ndarray


class StationDataset(Dataset):
    """
    PyTorch Dataset for station-based weather data.

    Loads sequences of observations from individual stations for
    training temporal forecasting models.

    Example usage:
        dataset = StationDataset(
            data_dir="data/storage/parquet",
            sequence_length=365,
            forecast_length=90,
            target_variables=["TMAX", "TMIN", "PRCP"],
        )
        sample = dataset[0]
    """

    def __init__(
        self,
        data_dir: Union[str, Path],
        sequence_length: int = 365,
        forecast_length: int = 90,
        target_variables: Optional[List[str]] = None,
        input_variables: Optional[List[str]] = None,
        start_year: Optional[int] = None,
        end_year: Optional[int] = None,
        station_ids: Optional[List[str]] = None,
        min_valid_ratio: float = 0.8,
        normalize: bool = True,
        cache_in_memory: bool = False,
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Directory containing processed Parquet files
            sequence_length: Number of days in input sequence
            forecast_length: Number of days to forecast
            target_variables: Variables to predict (default: TMAX, TMIN, PRCP)
            input_variables: Variables to use as input (default: all available)
            start_year: Start year for data (inclusive)
            end_year: End year for data (inclusive)
            station_ids: Specific stations to include (default: all)
            min_valid_ratio: Minimum ratio of valid values in a sequence
            normalize: Whether the stored data is already normalized
            cache_in_memory: Load all data into memory (faster, more RAM)
        """
        self.data_dir = Path(data_dir)
        self.sequence_length = sequence_length
        self.forecast_length = forecast_length
        self.total_length = sequence_length + forecast_length
        self.min_valid_ratio = min_valid_ratio
        self.normalize = normalize
        self.cache_in_memory = cache_in_memory

        # Default variables
        self.target_variables = target_variables or ["TMAX", "TMIN", "PRCP"]
        self.input_variables = input_variables

        # Load station metadata
        self.stations = self._load_stations()

        # Filter stations if specified
        if station_ids:
            self.stations = self.stations[self.stations["station_id"].isin(station_ids)]

        # Build index of valid samples
        self.samples = self._build_sample_index(start_year, end_year)

        # Cache for data
        self._cache: Dict[str, pd.DataFrame] = {}

        logger.info(
            f"StationDataset initialized: {len(self.stations)} stations, "
            f"{len(self.samples)} samples"
        )

    def _load_stations(self) -> pd.DataFrame:
        """Load station metadata."""
        stations_path = self.data_dir / "stations.parquet"
        if not stations_path.exists():
            raise FileNotFoundError(f"Station metadata not found: {stations_path}")

        return pd.read_parquet(stations_path)

    def _build_sample_index(
        self,
        start_year: Optional[int],
        end_year: Optional[int],
    ) -> List[Tuple[str, pd.Timestamp]]:
        """
        Build an index of valid training samples.

        Returns list of (station_id, start_date) tuples.
        """
        samples = []

        # Find available year files
        year_files = sorted(self.data_dir.glob("observations_*.parquet"))

        for year_file in year_files:
            year = int(year_file.stem.split("_")[1])

            # Filter by year range
            if start_year and year < start_year:
                continue
            if end_year and year > end_year:
                continue

            # Load year data
            df = pd.read_parquet(year_file)

            # Group by station
            for station_id, station_data in df.groupby("station_id"):
                # Check if station has enough data
                if len(station_data) < self.total_length:
                    continue

                # Find valid sequence start points
                # (where we have enough consecutive data)
                dates = station_data.index.sort_values()

                for i in range(len(dates) - self.total_length + 1):
                    start_date = dates[i]
                    end_date = dates[i + self.total_length - 1]

                    # Check for gaps (should be consecutive days)
                    expected_days = self.total_length
                    actual_days = (end_date - start_date).days + 1

                    if actual_days == expected_days:
                        # Check valid ratio
                        sample_data = station_data.loc[start_date:end_date]
                        target_cols = [c for c in self.target_variables if c in sample_data.columns]
                        valid_ratio = sample_data[target_cols].notna().mean().mean()

                        if valid_ratio >= self.min_valid_ratio:
                            samples.append((station_id, start_date))

        return samples

    def _load_station_data(self, station_id: str, year: int) -> pd.DataFrame:
        """Load data for a specific station and year."""
        cache_key = f"{station_id}_{year}"

        if cache_key in self._cache:
            return self._cache[cache_key]

        year_file = self.data_dir / f"observations_{year}.parquet"
        if not year_file.exists():
            return pd.DataFrame()

        df = pd.read_parquet(year_file)
        station_data = df[df["station_id"] == station_id].sort_index()

        if self.cache_in_memory:
            self._cache[cache_key] = station_data

        return station_data

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns dict with keys:
        - input_features: (seq_len, n_features)
        - input_mask: (seq_len,)
        - target_features: (forecast_len, n_targets)
        - target_mask: (forecast_len,)
        - station_coords: (3,) - [lat, lon, elev]
        - station_id: str
        """
        station_id, start_date = self.samples[idx]
        year = start_date.year

        # Load data (the sequence may span a year boundary; a missing
        # year file comes back as an empty DataFrame)
        data = self._load_station_data(station_id, year)
        next_year_data = self._load_station_data(station_id, year + 1)
        if not next_year_data.empty:
            data = pd.concat([data, next_year_data])

        # Extract sequence
        end_date = start_date + pd.Timedelta(days=self.total_length - 1)
        sequence = data.loc[start_date:end_date]

        if len(sequence) < self.total_length:
            # Pad if necessary
            sequence = sequence.reindex(
                pd.date_range(start_date, periods=self.total_length, freq="D")
            )

        # Get station metadata
        station_meta = self.stations[self.stations["station_id"] == station_id].iloc[0]

        # Prepare features
        feature_cols = self.input_variables or [
            c for c in sequence.columns
            if c not in ["station_id", "latitude", "longitude", "elevation", "year"]
        ]

        # Input sequence
        input_seq = sequence.iloc[:self.sequence_length]
        input_features = input_seq[feature_cols].values.astype(np.float32)
        input_mask = ~np.isnan(input_features).any(axis=1)

        # Target sequence
        target_seq = sequence.iloc[self.sequence_length:]
        target_cols = [c for c in self.target_variables if c in sequence.columns]
        target_features = target_seq[target_cols].values.astype(np.float32)
        target_mask = ~np.isnan(target_features).any(axis=1)

        # Fill NaN with 0 for tensor conversion (mask indicates valid values)
        input_features = np.nan_to_num(input_features, nan=0.0)
        target_features = np.nan_to_num(target_features, nan=0.0)

        # Station coordinates
        station_coords = np.array([
            station_meta["latitude"],
            station_meta["longitude"],
            station_meta["elevation"],
        ], dtype=np.float32)

        return {
            "input_features": torch.from_numpy(input_features),
            "input_mask": torch.from_numpy(input_mask),
            "target_features": torch.from_numpy(target_features),
            "target_mask": torch.from_numpy(target_mask),
            "station_coords": torch.from_numpy(station_coords),
            "station_id": station_id,
        }


class StationDataModule:
    """
    Data module for managing train/val/test splits.

    Provides DataLoaders with proper batching and shuffling.
    """

    def __init__(
        self,
        data_dir: Union[str, Path],
        batch_size: int = 32,
        num_workers: int = 4,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        sequence_length: int = 365,
        forecast_length: int = 90,
        **dataset_kwargs,
    ):
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.sequence_length = sequence_length
        self.forecast_length = forecast_length
        self.dataset_kwargs = dataset_kwargs

        self._train_dataset: Optional[StationDataset] = None
        self._val_dataset: Optional[StationDataset] = None
        self._test_dataset: Optional[StationDataset] = None

    def setup(self) -> None:
        """Set up train/val/test datasets."""
        # Load all stations
        stations = pd.read_parquet(self.data_dir / "stations.parquet")
        all_station_ids = stations["station_id"].tolist()

        # Shuffle and split
        np.random.seed(42)
        np.random.shuffle(all_station_ids)

        n_train = int(len(all_station_ids) * self.train_ratio)
        n_val = int(len(all_station_ids) * self.val_ratio)

        train_ids = all_station_ids[:n_train]
        val_ids = all_station_ids[n_train:n_train + n_val]
        test_ids = all_station_ids[n_train + n_val:]

        # Create datasets
        common_kwargs = {
            "data_dir": self.data_dir,
            "sequence_length": self.sequence_length,
            "forecast_length": self.forecast_length,
            **self.dataset_kwargs,
        }

        self._train_dataset = StationDataset(station_ids=train_ids, **common_kwargs)
        self._val_dataset = StationDataset(station_ids=val_ids, **common_kwargs)
        self._test_dataset = StationDataset(station_ids=test_ids, **common_kwargs)

        logger.info(
            f"Data split: {len(self._train_dataset)} train, "
            f"{len(self._val_dataset)} val, {len(self._test_dataset)} test"
        )

    def train_dataloader(self) -> DataLoader:
        """Get training DataLoader."""
        if self._train_dataset is None:
            self.setup()
        return DataLoader(
            self._train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation DataLoader."""
        if self._val_dataset is None:
            self.setup()
        return DataLoader(
            self._val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        """Get test DataLoader."""
        if self._test_dataset is None:
            self.setup()
        return DataLoader(
            self._test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )
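A short usage sketch for StationDataModule above; the path and keyword arguments are illustrative and assume the Parquet layout the pipeline produces (stations.parquet plus observations_<year>.parquet files):

from data.loaders.station_dataset import StationDataModule

dm = StationDataModule(
    data_dir="data/storage/parquet",
    batch_size=32,
    sequence_length=365,
    forecast_length=90,
    target_variables=["TMAX", "TMIN", "PRCP"],  # forwarded to StationDataset via **dataset_kwargs
)
dm.setup()  # station-level 80/10/10 split, seeded with 42 for reproducibility

for batch in dm.train_dataloader():
    # batch["input_features"]: (32, 365, n_features); zeros wherever input_mask is False
    break

Splitting by station rather than by time window keeps every sequence from a held-out station out of training, which avoids leakage between the splits.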
data/processed/ghcn_combined.parquet
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:425012f5baaed11b241efa923cfbeee6e7c9d5d775a0346fc38afd531903a3ca
size 44173477
data/processed/training/X.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d75e18276c806a855b1257fef17e990bafedab047332da757fc0d3d7ba6cca15
size 413353928
data/processed/training/Y.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02dde247a9479cc2ec8105ea89c260ad624454589b545e669fcbd913bceb45a0
size 192898568
data/processed/training/meta.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3ae1cc5c5c64f47f80c7fda04e936c0dd58182c0c72496bd46fce2483072513
size 18371408
data/processed/training/stats.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ae13e3bc39888f5c60faf71fad8cd307f147672db1d0fffdc82431a0a00edb1
size 1042
data/processing/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""Data Processing Pipeline."""

from data.processing.quality_control import QualityController
from data.processing.pipeline import DataPipeline

__all__ = ["QualityController", "DataPipeline"]
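The public surface declared in `__all__` above can then be imported directly from the package, e.g.:

from data.processing import DataPipeline, QualityController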
data/processing/ghcn_processor.py
ADDED
@@ -0,0 +1,319 @@
"""
GHCN Daily data processor - converts raw .dly files to training format
"""
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import pandas as pd
from datetime import datetime
from loguru import logger


class GHCNProcessor:
    """Process GHCN Daily files into training-ready format."""

    # GHCN file format: fixed-width columns
    # ID (11) + Year (4) + Month (2) + Element (4) + 31 * (Value(5) + MFlag(1) + QFlag(1) + SFlag(1))

    ELEMENTS = ['TMAX', 'TMIN', 'PRCP', 'SNOW', 'SNWD']
    MISSING_VALUE = -9999

    def __init__(self, raw_dir: Path, processed_dir: Path, stations_file: Optional[Path] = None):
        self.raw_dir = Path(raw_dir)
        self.processed_dir = Path(processed_dir)
        self.stations_file = stations_file
        self.stations_dir = self.raw_dir / "stations"
        self.processed_dir.mkdir(parents=True, exist_ok=True)

        # Load station metadata if available
        self.station_metadata = {}
        if stations_file and stations_file.exists():
            self._load_station_metadata()

    def _load_station_metadata(self):
        """Load station lat/lon from stations file."""
        with open(self.stations_file, 'r') as f:
            for line in f:
                # GHCN stations file format:
                # ID (11) + LAT (9) + LON (10) + ELEV (7) + STATE (3) + NAME (31) + ...
                station_id = line[0:11].strip()
                lat = float(line[12:20].strip())
                lon = float(line[21:30].strip())
                elev = float(line[31:37].strip()) if line[31:37].strip() else 0.0
                name = line[41:71].strip()
                self.station_metadata[station_id] = {
                    'lat': lat,
                    'lon': lon,
                    'elevation': elev,
                    'name': name
                }

    def parse_dly_file(self, filepath: Path) -> pd.DataFrame:
        """Parse a single .dly file into a DataFrame."""
        records = []

        with open(filepath, 'r') as f:
            for line in f:
                if len(line) < 269:  # Minimum valid line length
                    continue

                station_id = line[0:11]
                year = int(line[11:15])
                month = int(line[15:17])
                element = line[17:21]

                if element not in self.ELEMENTS:
                    continue

                # Parse 31 daily values
                for day in range(1, 32):
                    try:
                        start = 21 + (day - 1) * 8
                        value_str = line[start:start+5].strip()
                        mflag = line[start+5:start+6]
                        qflag = line[start+6:start+7]

                        if not value_str:
                            continue

                        value = int(value_str)

                        # Skip missing values and flagged quality issues
                        if value == self.MISSING_VALUE:
                            continue
                        if qflag.strip():  # Has a quality flag
                            continue

                        # Create date
                        try:
                            date = datetime(year, month, day)
                        except ValueError:
                            continue  # Invalid date (e.g., Feb 30)

                        records.append({
                            'station_id': station_id,
                            'date': date,
                            'element': element,
                            'value': value
                        })
                    except (ValueError, IndexError):
                        continue

        if not records:
            return pd.DataFrame()

        df = pd.DataFrame(records)

        # Pivot to get elements as columns
        df = df.pivot_table(
            index=['station_id', 'date'],
            columns='element',
            values='value',
            aggfunc='first'
        ).reset_index()

        # Convert units: temps from tenths of °C, precip from tenths of mm.
        # SNOW and SNWD are reported in whole millimetres and need no scaling.
        if 'TMAX' in df.columns:
            df['TMAX'] = df['TMAX'] / 10.0
        if 'TMIN' in df.columns:
            df['TMIN'] = df['TMIN'] / 10.0
        if 'PRCP' in df.columns:
            df['PRCP'] = df['PRCP'] / 10.0

        return df

    def process_all_stations(self, min_years: int = 10) -> pd.DataFrame:
        """Process all station files and combine."""
        all_data = []
        station_files = list(self.stations_dir.glob("*.dly"))

        logger.info(f"Processing {len(station_files)} station files...")

        for i, filepath in enumerate(station_files):
            if (i + 1) % 50 == 0:
                logger.info(f"Processed {i + 1}/{len(station_files)} stations")

            df = self.parse_dly_file(filepath)
            if df.empty:
                continue

            # Check if station has enough data
            years_of_data = (df['date'].max() - df['date'].min()).days / 365
            if years_of_data < min_years:
                continue

            # Add station metadata
            station_id = filepath.stem
            if station_id in self.station_metadata:
                meta = self.station_metadata[station_id]
                df['lat'] = meta['lat']
                df['lon'] = meta['lon']
                df['elevation'] = meta['elevation']

            all_data.append(df)

        if not all_data:
            logger.error("No valid station data found!")
            return pd.DataFrame()

        combined = pd.concat(all_data, ignore_index=True)
        logger.success(f"Combined {len(combined)} records from {len(all_data)} stations")

        return combined

    def create_training_sequences(
        self,
        df: pd.DataFrame,
        input_days: int = 30,
        target_days: int = 14,
        stride: int = 7
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Create training sequences for the model.

        Args:
            df: DataFrame with processed weather data
            input_days: Number of days of history to use as input
            target_days: Number of days to predict
            stride: Step size between sequences

        Returns:
            X: Input sequences [N, input_days, features]
            Y: Target sequences [N, target_days, features]
            meta: Station metadata [N, 4] (lat, lon, elev, day_of_year)
        """
        sequences_X = []
        sequences_Y = []
        sequences_meta = []

        # Features we'll use
        features = ['TMAX', 'TMIN', 'PRCP']

        # Process each station separately
        stations = df['station_id'].unique()
        logger.info(f"Creating sequences from {len(stations)} stations...")

        for station_id in stations:
            station_df = df[df['station_id'] == station_id].copy()
            station_df = station_df.sort_values('date')

            # Ensure we have required features
            for feat in features:
                if feat not in station_df.columns:
                    station_df[feat] = np.nan

            # Fill missing values with interpolation
            station_df[features] = station_df[features].interpolate(method='linear', limit=7)

            # Drop rows with too many NaN
            station_df = station_df.dropna(subset=['TMAX', 'TMIN'])

            if len(station_df) < input_days + target_days:
                continue

            # Get metadata
            lat = station_df['lat'].iloc[0] if 'lat' in station_df.columns else 0
            lon = station_df['lon'].iloc[0] if 'lon' in station_df.columns else 0
            elev = station_df['elevation'].iloc[0] if 'elevation' in station_df.columns else 0

            # Create sequences
            values = station_df[features].values
            dates = station_df['date'].values

            for i in range(0, len(values) - input_days - target_days, stride):
                X = values[i:i + input_days]
                Y = values[i + input_days:i + input_days + target_days]

                # Skip if too many NaN
                if np.isnan(X).sum() > input_days * len(features) * 0.3:
                    continue
                if np.isnan(Y).sum() > target_days * len(features) * 0.3:
                    continue

                # Fill remaining NaN with mean
                X = np.nan_to_num(X, nan=np.nanmean(X))
                Y = np.nan_to_num(Y, nan=np.nanmean(Y))

                # Get day of year for the first target day
                target_date = pd.Timestamp(dates[i + input_days])
                day_of_year = target_date.dayofyear / 365.0  # Normalize

                sequences_X.append(X)
                sequences_Y.append(Y)
                sequences_meta.append([lat, lon, elev, day_of_year])

        if not sequences_X:
            logger.error("No valid sequences created!")
            return np.array([]), np.array([]), np.array([])

        X = np.array(sequences_X, dtype=np.float32)
        Y = np.array(sequences_Y, dtype=np.float32)
        meta = np.array(sequences_meta, dtype=np.float32)

        logger.success(f"Created {len(X)} training sequences")
        logger.info(f"X shape: {X.shape}, Y shape: {Y.shape}, meta shape: {meta.shape}")

        return X, Y, meta

    def save_training_data(self, X: np.ndarray, Y: np.ndarray, meta: np.ndarray):
        """Save processed training data."""
        output_dir = self.processed_dir / "training"
        output_dir.mkdir(parents=True, exist_ok=True)

        np.save(output_dir / "X.npy", X)
        np.save(output_dir / "Y.npy", Y)
        np.save(output_dir / "meta.npy", meta)

        logger.success(f"Saved training data to {output_dir}")

        # Save normalization stats
        stats = {
            'X_mean': X.mean(axis=(0, 1)),
            'X_std': X.std(axis=(0, 1)),
            'Y_mean': Y.mean(axis=(0, 1)),
            'Y_std': Y.std(axis=(0, 1)),
        }
        np.savez(output_dir / "stats.npz", **stats)


def main():
    """Process GHCN data for training."""
    from pathlib import Path

    base_dir = Path(__file__).parent.parent.parent
    raw_dir = base_dir / "data" / "raw" / "ghcn_daily"
    processed_dir = base_dir / "data" / "processed"
    stations_file = raw_dir / "ghcnd-stations.txt"

    processor = GHCNProcessor(raw_dir, processed_dir, stations_file)

    # Process all stations
    df = processor.process_all_stations(min_years=10)

    if df.empty:
        logger.error("No data to process!")
        return

    # Save combined data for inspection
    df.to_parquet(processed_dir / "ghcn_combined.parquet")
    logger.info(f"Saved combined data to {processed_dir / 'ghcn_combined.parquet'}")

    # Create training sequences
    X, Y, meta = processor.create_training_sequences(
        df,
        input_days=30,
        target_days=14,
        stride=7
    )

    if len(X) > 0:
        processor.save_training_data(X, Y, meta)


if __name__ == "__main__":
    main()
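A minimal sketch of consuming the artifacts written by `save_training_data` above; the z-score step is an assumption about how downstream training code might apply `stats.npz`, not something this module does itself:

import numpy as np

train_dir = "data/processed/training"
X = np.load(f"{train_dir}/X.npy")        # (N, 30, 3) history of TMAX, TMIN, PRCP
Y = np.load(f"{train_dir}/Y.npy")        # (N, 14, 3) forecast targets
meta = np.load(f"{train_dir}/meta.npy")  # (N, 4) lat, lon, elev, normalized day-of-year

stats = np.load(f"{train_dir}/stats.npz")
X_norm = (X - stats["X_mean"]) / (stats["X_std"] + 1e-8)  # per-feature z-score; epsilon guards zero std
Y_norm = (Y - stats["Y_mean"]) / (stats["Y_std"] + 1e-8)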
data/processing/pipeline.py
ADDED
@@ -0,0 +1,469 @@
"""
Data Processing Pipeline

Orchestrates the full data processing workflow:
1. Load raw GHCN data
2. Apply quality control
3. Normalize and encode features
4. Grid data (station → regular grid)
5. Save to efficient formats (Parquet/Zarr)
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from loguru import logger
from tqdm import tqdm

from data.download.ghcn_daily import GHCNDailyDownloader
from data.processing.quality_control import QualityController


@dataclass
class PipelineConfig:
    """Configuration for the data pipeline."""

    # Input/Output
    raw_dir: str = "data/raw/ghcn_daily"
    output_dir: str = "data/storage/parquet"
    tensor_dir: str = "data/storage/zarr"

    # Processing
    min_years: int = 30
    min_observations_per_year: int = 300
    target_variables: Optional[List[str]] = None

    # Normalization
    normalize: bool = True
    clip_outliers: bool = True
    outlier_std: float = 5.0

    # Gridding
    grid_resolution: float = 0.25  # degrees
    interpolation_method: str = "idw"  # 'idw', 'kriging', 'nearest'
    max_interpolation_distance: float = 2.0  # degrees

    def __post_init__(self):
        if self.target_variables is None:
            self.target_variables = ["TMAX", "TMIN", "PRCP", "SNOW", "SNWD"]


class FeatureEncoder:
    """
    Encodes and normalizes weather features for ML training.

    Handles:
    - Cyclical encoding for time features (day of year, hour)
    - Log transformation for precipitation
    - Standard normalization for temperatures
    - Sin/cos encoding for wind direction
    """

    def __init__(self):
        self.stats: dict[str, dict[str, float]] = {}

    def fit(self, df: pd.DataFrame) -> "FeatureEncoder":
        """Compute normalization statistics from data."""
        for col in df.select_dtypes(include=[np.number]).columns:
            # Cast to plain floats so save() can JSON-serialize the stats
            self.stats[col] = {
                "mean": float(df[col].mean()),
                "std": float(df[col].std()),
                "min": float(df[col].min()),
                "max": float(df[col].max()),
            }
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply encoding and normalization."""
        result = df.copy()

        # Add time features
        if isinstance(df.index, pd.DatetimeIndex):
            # Day of year (cyclical)
            day_of_year = df.index.dayofyear
            result["day_sin"] = np.sin(2 * np.pi * day_of_year / 365)
            result["day_cos"] = np.cos(2 * np.pi * day_of_year / 365)

            # Month (cyclical)
            month = df.index.month
            result["month_sin"] = np.sin(2 * np.pi * month / 12)
            result["month_cos"] = np.cos(2 * np.pi * month / 12)

        # Normalize numerical columns
        for col in df.select_dtypes(include=[np.number]).columns:
            if col in self.stats:
                stats = self.stats[col]

                # Special handling for precipitation (log transform)
                if "prcp" in col.lower() or "precip" in col.lower():
                    # Log1p transform for precipitation
                    result[col] = np.log1p(df[col].clip(lower=0))
                else:
                    # Standard normalization
                    if stats["std"] > 0:
                        result[col] = (df[col] - stats["mean"]) / stats["std"]
                    else:
                        result[col] = 0.0

        # Wind direction encoding (if present)
        for col in ["wind_direction", "WDIR"]:
            if col in df.columns:
                rad = np.deg2rad(df[col])
                result[f"{col}_sin"] = np.sin(rad)
                result[f"{col}_cos"] = np.cos(rad)
                result = result.drop(columns=[col])

        return result

    def inverse_transform(self, df: pd.DataFrame, columns: Optional[List[str]] = None) -> pd.DataFrame:
        """Reverse normalization for predictions."""
        result = df.copy()
        columns = columns or list(self.stats.keys())

        for col in columns:
            if col not in self.stats or col not in df.columns:
                continue

            stats = self.stats[col]

            if "prcp" in col.lower() or "precip" in col.lower():
                # Reverse log1p
                result[col] = np.expm1(df[col])
            else:
                # Reverse standard normalization
                result[col] = df[col] * stats["std"] + stats["mean"]

        return result

    def save(self, path: str) -> None:
        """Save encoder statistics to file."""
        import json

        with open(path, "w") as f:
            json.dump(self.stats, f)

    @classmethod
    def load(cls, path: str) -> "FeatureEncoder":
        """Load encoder from file."""
        import json

        encoder = cls()
        with open(path) as f:
            encoder.stats = json.load(f)
        return encoder


class SpatialGridder:
    """
    Converts irregular station data to regular lat/lon grid.

    Uses inverse distance weighting (IDW) or other interpolation methods
    to create gridded fields from station observations.
    """

    def __init__(
        self,
        resolution: float = 0.25,
        method: str = "idw",
        max_distance: float = 2.0,
        power: float = 2.0,
    ):
        self.resolution = resolution
        self.method = method
        self.max_distance = max_distance
        self.power = power

        # Create grid
        self.lat_grid = np.arange(-90, 90 + resolution, resolution)
        self.lon_grid = np.arange(-180, 180, resolution)

    def grid_stations(
        self,
        stations: pd.DataFrame,
        variable: str,
    ) -> np.ndarray:
        """
        Grid station observations to regular grid.

        Args:
            stations: DataFrame with columns ['latitude', 'longitude', variable]
            variable: Column name to grid

        Returns:
            2D array of shape (n_lat, n_lon)
        """
        # Initialize output grid
        grid = np.full((len(self.lat_grid), len(self.lon_grid)), np.nan)

        # Get valid stations
        valid = stations[["latitude", "longitude", variable]].dropna()
        if len(valid) == 0:
            return grid

        station_lats = valid["latitude"].values
        station_lons = valid["longitude"].values
        station_vals = valid[variable].values

        # IDW interpolation
        for i, lat in enumerate(self.lat_grid):
            for j, lon in enumerate(self.lon_grid):
                # Calculate distances to all stations
                dlat = station_lats - lat
                dlon = station_lons - lon

                # Approximate distance in degrees
                distances = np.sqrt(dlat**2 + dlon**2)

                # Find stations within max distance
                mask = distances < self.max_distance
                if not mask.any():
                    continue

                nearby_distances = distances[mask]
                nearby_values = station_vals[mask]

                # Handle exact matches (distance = 0)
                if (nearby_distances == 0).any():
                    grid[i, j] = nearby_values[nearby_distances == 0][0]
                else:
                    # IDW weights
                    weights = 1.0 / (nearby_distances ** self.power)
                    grid[i, j] = np.average(nearby_values, weights=weights)

        return grid


class DataPipeline:
    """
    Main data processing pipeline.

    Coordinates downloading, quality control, encoding, and output.

    Example usage:
        pipeline = DataPipeline(config)
        pipeline.run()
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        self.config = config or PipelineConfig()
        self.downloader = GHCNDailyDownloader(output_dir=self.config.raw_dir)
        self.qc = QualityController()
        self.encoder = FeatureEncoder()
        self.gridder = SpatialGridder(resolution=self.config.grid_resolution)

        # Ensure output directories exist
        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)
        Path(self.config.tensor_dir).mkdir(parents=True, exist_ok=True)

    def run(
        self,
        stations: Optional[list] = None,
        max_stations: Optional[int] = None,
        download: bool = True,
    ) -> None:
        """
        Run the full pipeline.

        Args:
            stations: List of stations to process (or download new)
            max_stations: Maximum number of stations to process
            download: Whether to download data if not present
        """
        logger.info("Starting data pipeline")

        # 1. Get stations
        if stations is None:
            if download:
                self.downloader.download_stations()
                self.downloader.download_inventory()

            stations = self.downloader.get_stations(
                min_years=self.config.min_years,
                elements=self.config.target_variables,
            )

        if max_stations:
            stations = stations[:max_stations]

        logger.info(f"Processing {len(stations)} stations")

        # 2. Process each station
        all_data = []
        station_metadata = []

        for station in tqdm(stations, desc="Processing stations"):
            try:
                # Download if needed
                if download:
                    self.downloader.download_station_data(station.id)

                # Load and process
                df = self.downloader.station_to_dataframe(station.id)
                if df.empty:
                    continue

                # Quality control
                df_clean, flags = self.qc.process(df, station_id=station.id)

                # Fill small gaps
                df_clean, fill_flags = self.qc.fill_gaps(df_clean)

                # Filter to target variables
                target_cols = [c for c in self.config.target_variables if c in df_clean.columns]
                if not target_cols:
                    continue

                df_clean = df_clean[target_cols]

                # Add station metadata
                df_clean["station_id"] = station.id
                df_clean["latitude"] = station.latitude
                df_clean["longitude"] = station.longitude
                df_clean["elevation"] = station.elevation

                all_data.append(df_clean)
                station_metadata.append({
                    "station_id": station.id,
                    "name": station.name,
                    "latitude": station.latitude,
                    "longitude": station.longitude,
                    "elevation": station.elevation,
                    "country": station.id[:2],
                    "start_date": df_clean.index.min().isoformat(),
                    "end_date": df_clean.index.max().isoformat(),
                    "n_observations": len(df_clean),
                })

            except Exception as e:
                logger.warning(f"Error processing {station.id}: {e}")
                continue

        if not all_data:
            logger.error("No data processed successfully")
            return

        # 3. Combine all data
        logger.info("Combining station data")
        combined = pd.concat(all_data)

        # 4. Fit encoder on full dataset
        logger.info("Fitting feature encoder")
        numeric_cols = combined.select_dtypes(include=[np.number]).columns
        numeric_cols = [c for c in numeric_cols if c not in ["latitude", "longitude", "elevation"]]
        self.encoder.fit(combined[numeric_cols])

        # 5. Save encoder
        encoder_path = Path(self.config.output_dir) / "encoder.json"
        self.encoder.save(str(encoder_path))
        logger.info(f"Saved encoder to {encoder_path}")

        # 6. Save station metadata
        metadata_df = pd.DataFrame(station_metadata)
        metadata_path = Path(self.config.output_dir) / "stations.parquet"
        metadata_df.to_parquet(metadata_path)
        logger.info(f"Saved {len(metadata_df)} stations to {metadata_path}")

        # 7. Save processed data (partitioned by year)
        logger.info("Saving processed data")
        combined["year"] = combined.index.year

        for year, year_data in combined.groupby("year"):
            year_path = Path(self.config.output_dir) / f"observations_{year}.parquet"
            year_data.to_parquet(year_path)

        logger.success(f"Pipeline complete. Processed {len(station_metadata)} stations, {len(combined)} observations")

    def create_training_tensors(
        self,
        start_year: int = 1950,
        end_year: int = 2023,
        sequence_length: int = 365,
    ) -> None:
        """
        Create training tensors from processed data.

        Outputs Zarr arrays suitable for PyTorch DataLoaders.
        """
        import zarr

        logger.info(f"Creating training tensors for {start_year}-{end_year}")

        output_path = Path(self.config.tensor_dir)

        # Load encoder
        encoder_path = Path(self.config.output_dir) / "encoder.json"
        if encoder_path.exists():
            self.encoder = FeatureEncoder.load(str(encoder_path))

        # Load station metadata
        stations = pd.read_parquet(Path(self.config.output_dir) / "stations.parquet")

        # Initialize Zarr store
        store = zarr.DirectoryStore(str(output_path / "training"))
        root = zarr.group(store)

        # Process year by year
        all_features = []
        all_targets = []
        all_station_ids = []
        all_timestamps = []

        for year in tqdm(range(start_year, end_year + 1), desc="Years"):
            year_path = Path(self.config.output_dir) / f"observations_{year}.parquet"
            if not year_path.exists():
                continue

            df = pd.read_parquet(year_path)

            # Encode features
            encoded = self.encoder.transform(df[self.config.target_variables])

            # Store
            all_features.append(encoded.values)
            all_station_ids.extend(df["station_id"].tolist())
            all_timestamps.extend(df.index.tolist())

        # Concatenate and save
        if all_features:
            features = np.concatenate(all_features, axis=0)
            root.create_dataset("features", data=features, chunks=(10000, features.shape[1]))
            root.attrs["n_samples"] = len(features)
            root.attrs["feature_names"] = list(self.encoder.stats.keys())

            logger.success(f"Created training tensors: {features.shape}")


def main():
    """CLI entry point for the data pipeline."""
    import argparse

    parser = argparse.ArgumentParser(description="Run LILITH data pipeline")
    parser.add_argument("--raw-dir", default="data/raw/ghcn_daily", help="Raw data directory")
    parser.add_argument("--output-dir", default="data/storage/parquet", help="Output directory")
    parser.add_argument("--max-stations", type=int, default=None, help="Max stations to process")
    parser.add_argument("--min-years", type=int, default=30, help="Min years of data required")
    parser.add_argument("--no-download", action="store_true", help="Don't download new data")
    parser.add_argument("--create-tensors", action="store_true", help="Create training tensors")

    args = parser.parse_args()

    config = PipelineConfig(
        raw_dir=args.raw_dir,
        output_dir=args.output_dir,
        min_years=args.min_years,
    )

    pipeline = DataPipeline(config)
    pipeline.run(max_stations=args.max_stations, download=not args.no_download)

    if args.create_tensors:
        pipeline.create_training_tensors()
| 466 |
+
|
| 467 |
+
|
| 468 |
+
if __name__ == "__main__":
|
| 469 |
+
main()
|
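The module is driven from the CLI above, e.g. python data/processing/pipeline.py --max-stations 500 --create-tensors. The Zarr store written by create_training_tensors can then be consumed from PyTorch. The sketch below is illustrative, not part of the upload: it assumes the zarr v2 API used above plus a torch installation, and the store path (which depends on PipelineConfig.tensor_dir) is a placeholder.

# Sketch: load the "features" array written by create_training_tensors into
# a PyTorch DataLoader. Assumes zarr v2 and torch; the path is illustrative.
import numpy as np
import torch
import zarr
from torch.utils.data import DataLoader, TensorDataset

root = zarr.open_group("data/processed/training", mode="r")
features = np.asarray(root["features"])  # shape: (n_samples, n_features)
print(root.attrs["n_samples"], root.attrs["feature_names"])

loader = DataLoader(
    TensorDataset(torch.from_numpy(features).float()),
    batch_size=256,
    shuffle=True,
)
for (batch,) in loader:
    break  # one (256, n_features) batch of encoded observations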
data/processing/quality_control.py
ADDED
@@ -0,0 +1,404 @@
"""
Quality Control for GHCN Data

Implements quality checks and cleaning procedures for weather observations.
Based on GHCN quality control flags and additional statistical checks.
"""

from dataclasses import dataclass
from enum import Enum
from typing import Optional

import numpy as np
import pandas as pd
from loguru import logger


class QCFlag(Enum):
    """Quality control flag values."""

    PASSED = "P"            # Passed all checks
    DUPLICATE = "D"         # Duplicate value
    GAP_FILLED = "G"        # Value was interpolated
    SUSPECT_RANGE = "R"     # Outside valid range
    SUSPECT_SPATIAL = "S"   # Spatial consistency check failed
    SUSPECT_TEMPORAL = "T"  # Temporal consistency check failed
    SUSPECT_CLIMATE = "C"   # Exceeds climatological bounds
    FAILED = "F"            # Failed quality check, value removed


@dataclass
class QCConfig:
    """Configuration for quality control checks."""

    # Temperature bounds (°C)
    temp_min: float = -90.0
    temp_max: float = 60.0
    temp_daily_change_max: float = 30.0  # Max change between consecutive days

    # Precipitation bounds (mm)
    precip_min: float = 0.0
    precip_max: float = 1000.0  # Single-day max

    # Wind bounds (m/s)
    wind_min: float = 0.0
    wind_max: float = 120.0

    # Pressure bounds (hPa)
    pressure_min: float = 870.0
    pressure_max: float = 1085.0

    # Spike detection
    spike_threshold: float = 4.0  # Standard deviations

    # Climatology bounds (number of standard deviations from monthly mean)
    climate_std_threshold: float = 5.0

    # Gap filling
    max_gap_hours: int = 6  # Maximum gap to interpolate for hourly data
    max_gap_days: int = 3   # Maximum gap to interpolate for daily data
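Every physical bound above is a plain dataclass field, so site-specific limits are one constructor call away. A minimal sketch (numeric values illustrative), using the QualityController defined just below:

# Sketch: relax the upper temperature bound and tighten spike detection
# for a hot-climate network. Field names come from QCConfig above; the
# values themselves are illustrative only.
config = QCConfig(temp_max=55.0, spike_threshold=3.0)
qc = QualityController(config)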
class QualityController:
    """
    Applies quality control checks to weather observation data.

    Checks include:
    1. Range checks (physical bounds)
    2. Temporal consistency (spike detection)
    3. Spatial consistency (comparison with neighbors)
    4. Climatological bounds
    5. Duplicate detection

    Example usage:
        qc = QualityController()
        df_clean, flags = qc.process(df)
    """

    def __init__(self, config: Optional[QCConfig] = None):
        self.config = config or QCConfig()
        self._climatology: Optional[pd.DataFrame] = None

    def process(
        self,
        df: pd.DataFrame,
        station_id: Optional[str] = None,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Apply all quality control checks to a DataFrame.

        Args:
            df: DataFrame with datetime index and weather variable columns
            station_id: Optional station identifier for logging

        Returns:
            Tuple of (cleaned_df, flags_df) where flags_df contains QC flags
        """
        logger.info(f"Running QC on {len(df)} records" + (f" for {station_id}" if station_id else ""))

        # Initialize flags DataFrame
        flags = pd.DataFrame(index=df.index)
        for col in df.columns:
            flags[f"{col}_flag"] = QCFlag.PASSED.value

        # Create working copy
        df_clean = df.copy()

        # 1. Range checks
        df_clean, flags = self._range_check(df_clean, flags)

        # 2. Temporal consistency (spike detection)
        df_clean, flags = self._temporal_check(df_clean, flags)

        # 3. Duplicate detection
        df_clean, flags = self._duplicate_check(df_clean, flags)

        # 4. Climatological bounds (if climatology is loaded)
        if self._climatology is not None:
            df_clean, flags = self._climate_check(df_clean, flags, station_id)

        # Count flags
        for col in df.columns:
            flag_col = f"{col}_flag"
            if flag_col in flags.columns:
                flag_counts = flags[flag_col].value_counts()
                for flag, count in flag_counts.items():
                    if flag != QCFlag.PASSED.value:
                        logger.debug(f"{col}: {count} records flagged as {flag}")

        # Calculate overall pass rate
        total_checks = len(df) * len(df.columns)
        passed = sum(
            (flags[f"{col}_flag"] == QCFlag.PASSED.value).sum()
            for col in df.columns
            if f"{col}_flag" in flags.columns
        )
        pass_rate = passed / total_checks if total_checks > 0 else 0
        logger.info(f"QC pass rate: {pass_rate:.1%}")

        return df_clean, flags
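Because process returns the flags alongside the cleaned frame rather than silently dropping suspect rows, the caller decides how strict to be. A self-contained sketch that keeps only rows passing every check (the toy frame and its 150 °C outlier are illustrative):

# Sketch: run QC on a toy frame and keep only rows where every variable
# passed. The 150 °C reading exceeds QCConfig.temp_max, so it is flagged
# "R", set to NaN, and its row is dropped from the strict view.
import pandas as pd

idx = pd.date_range("2020-01-01", periods=5, freq="D")
df = pd.DataFrame({"TMAX": [12.0, 14.0, 150.0, 13.5, 12.8]}, index=idx)

qc = QualityController()
df_clean, flags = qc.process(df)

all_passed = flags.eq(QCFlag.PASSED.value).all(axis=1)
df_strict = df_clean[all_passed]  # 4 rows remain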
    def _range_check(
        self,
        df: pd.DataFrame,
        flags: pd.DataFrame,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Apply physical range checks."""
        cfg = self.config

        # Temperature columns
        for col in ["TMAX", "TMIN", "TAVG", "temperature", "temp_mean", "temp_max", "temp_min"]:
            if col in df.columns:
                mask = (df[col] < cfg.temp_min) | (df[col] > cfg.temp_max)
                flags.loc[mask, f"{col}_flag"] = QCFlag.SUSPECT_RANGE.value
                df.loc[mask, col] = np.nan

        # TMAX should be >= TMIN
        if "TMAX" in df.columns and "TMIN" in df.columns:
            mask = df["TMAX"] < df["TMIN"]
            flags.loc[mask, "TMAX_flag"] = QCFlag.SUSPECT_RANGE.value
            flags.loc[mask, "TMIN_flag"] = QCFlag.SUSPECT_RANGE.value

        # Precipitation
        for col in ["PRCP", "precipitation", "precip", "precipitation_1h", "precipitation_6h"]:
            if col in df.columns:
                mask = (df[col] < cfg.precip_min) | (df[col] > cfg.precip_max)
                flags.loc[mask, f"{col}_flag"] = QCFlag.SUSPECT_RANGE.value
                df.loc[mask, col] = np.nan

        # Wind speed
        for col in ["wind_speed", "AWND", "wind_gust"]:
            if col in df.columns:
                mask = (df[col] < cfg.wind_min) | (df[col] > cfg.wind_max)
                flags.loc[mask, f"{col}_flag"] = QCFlag.SUSPECT_RANGE.value
                df.loc[mask, col] = np.nan

        # Pressure
        for col in ["sea_level_pressure", "station_pressure", "pressure"]:
            if col in df.columns:
                mask = (df[col] < cfg.pressure_min) | (df[col] > cfg.pressure_max)
                flags.loc[mask, f"{col}_flag"] = QCFlag.SUSPECT_RANGE.value
                df.loc[mask, col] = np.nan

        return df, flags

    def _temporal_check(
        self,
        df: pd.DataFrame,
        flags: pd.DataFrame,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Check for temporal consistency (spike detection).

        Uses a rolling window to detect values that deviate significantly
        from their temporal neighbors.
        """
        cfg = self.config

        for col in df.columns:
            if df[col].dtype not in [np.float64, np.float32, np.int64, np.int32]:
                continue

            # Calculate rolling statistics
            window = 7 if "temp" in col.lower() or col in ["TMAX", "TMIN", "TAVG"] else 3
            rolling_mean = df[col].rolling(window, center=True, min_periods=1).mean()
            rolling_std = df[col].rolling(window, center=True, min_periods=1).std()

            # Flag values that deviate too much from the rolling mean
            deviation = np.abs(df[col] - rolling_mean)
            threshold = cfg.spike_threshold * rolling_std.clip(lower=0.1)  # Minimum std

            mask = deviation > threshold
            mask = mask & ~df[col].isna()  # Don't flag already-missing values

            if mask.any():
                # Update flags (don't overwrite worse flags)
                current_flags = flags[f"{col}_flag"]
                new_flags = current_flags.where(
                    current_flags != QCFlag.PASSED.value,
                    QCFlag.SUSPECT_TEMPORAL.value,
                )
                flags.loc[mask, f"{col}_flag"] = new_flags[mask]

            # Optionally remove values (or just flag them)
            # df.loc[mask, col] = np.nan

        return df, flags

    def _duplicate_check(
        self,
        df: pd.DataFrame,
        flags: pd.DataFrame,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Check for duplicate records.

        Flags rows with identical timestamps or suspiciously repeated values.
        """
        # Check for duplicate indices
        if df.index.duplicated().any():
            dup_mask = df.index.duplicated(keep="first")
            for col in df.columns:
                flag_col = f"{col}_flag"
                if flag_col in flags.columns:
                    flags.loc[dup_mask, flag_col] = QCFlag.DUPLICATE.value

            # Remove duplicates (keep first)
            df = df[~df.index.duplicated(keep="first")]
            flags = flags[~flags.index.duplicated(keep="first")]

        # Check for stuck sensors (many repeated values)
        for col in df.columns:
            if df[col].dtype not in [np.float64, np.float32, np.int64, np.int32]:
                continue

            # Count consecutive identical values
            shifted = df[col].shift(1)
            same_as_prev = df[col] == shifted
            consecutive_same = same_as_prev.groupby((~same_as_prev).cumsum()).cumsum()

            # Flag if more than 5 consecutive identical values (possible stuck sensor)
            stuck_mask = consecutive_same > 5
            if stuck_mask.any():
                logger.debug(f"Possible stuck sensor detected in {col}")
                # Just log, don't automatically flag (could be valid calm conditions)

        return df, flags
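The run-length one-liner in the stuck-sensor check is dense enough to deserve a worked example. On a toy series (illustrative, not part of the module):

# Sketch: how the consecutive-identical-value counter behaves.
import pandas as pd

s = pd.Series([5.0, 5.0, 5.0, 7.0, 7.0, 8.0])
same_as_prev = s == s.shift(1)        # [False, True, True, False, True, False]
group_ids = (~same_as_prev).cumsum()  # [1, 1, 1, 2, 2, 3] — one id per run
runs = same_as_prev.groupby(group_ids).cumsum()
# runs: [0, 1, 2, 0, 1, 0] — repeats counted after each run's first value,
# so `runs > 5` fires from the seventh identical reading onward.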
    def _climate_check(
        self,
        df: pd.DataFrame,
        flags: pd.DataFrame,
        station_id: Optional[str] = None,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Check values against climatological bounds.

        Requires climatology data to be loaded first.
        """
        if self._climatology is None:
            return df, flags

        cfg = self.config

        # Get month for each record
        months = df.index.month

        for col in df.columns:
            # Climatology columns are stored as "{col}_mean"/"{col}_std"
            if f"{col}_mean" not in self._climatology.columns:
                continue

            # Get climatology for each month
            clim_mean = months.map(
                lambda m: self._climatology.loc[m, f"{col}_mean"]
                if m in self._climatology.index
                else np.nan
            )
            clim_std = months.map(
                lambda m: self._climatology.loc[m, f"{col}_std"]
                if m in self._climatology.index
                else np.nan
            )

            # Flag values outside climatological bounds
            deviation = np.abs(df[col] - clim_mean)
            threshold = cfg.climate_std_threshold * clim_std

            mask = deviation > threshold
            mask = mask & ~df[col].isna()

            if mask.any():
                flags.loc[mask, f"{col}_flag"] = QCFlag.SUSPECT_CLIMATE.value

        return df, flags

    def load_climatology(self, path: str) -> None:
        """
        Load climatology data for climate checks.

        Expects a CSV with columns: month, {variable}_mean, {variable}_std
        """
        self._climatology = pd.read_csv(path, index_col="month")
        logger.info(f"Loaded climatology with {len(self._climatology)} months")
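The climatology layout is one row per calendar month. A sketch that writes a minimal TMAX-only file and loads it (column names follow the docstring above; the values and the three-month subset are illustrative):

# Sketch: build and load a minimal climatology CSV in the layout that
# load_climatology expects. A real file would cover all twelve months.
import pandas as pd

clim = pd.DataFrame(
    {"TMAX_mean": [3.2, 28.6, 4.1], "TMAX_std": [4.1, 3.2, 4.0]},
    index=pd.Index([1, 7, 12], name="month"),
)
clim.to_csv("climatology.csv")

qc = QualityController()
qc.load_climatology("climatology.csv")  # logs "Loaded climatology with 3 months"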
    def fill_gaps(
        self,
        df: pd.DataFrame,
        method: str = "linear",
        max_gap: Optional[int] = None,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Fill small gaps in the data using interpolation.

        Args:
            df: DataFrame with datetime index
            method: Interpolation method ('linear', 'time', 'spline')
            max_gap: Maximum gap size to fill (uses config default if None)

        Returns:
            Tuple of (filled_df, flags_df) indicating which values were interpolated
        """
        if max_gap is None:
            # Determine if hourly or daily based on index frequency
            if len(df) > 1:
                freq = pd.infer_freq(df.index)
                if freq and "h" in freq.lower():  # "H" in older pandas, "h" in newer
                    max_gap = self.config.max_gap_hours
                else:
                    max_gap = self.config.max_gap_days
            else:
                max_gap = self.config.max_gap_days

        # Track which values were interpolated
        was_null = df.isna()

        # Interpolate
        df_filled = df.interpolate(method=method, limit=max_gap)

        # Create flags for interpolated values
        flags = pd.DataFrame(index=df.index)
        for col in df.columns:
            flags[f"{col}_flag"] = np.where(
                was_null[col] & ~df_filled[col].isna(),
                QCFlag.GAP_FILLED.value,
                QCFlag.PASSED.value,
            )

        return df_filled, flags


def main():
    """CLI entry point for running quality control."""
    import argparse

    parser = argparse.ArgumentParser(description="Run quality control on weather data")
    parser.add_argument("input", help="Input CSV or Parquet file")
    parser.add_argument("output", help="Output file path")
    parser.add_argument("--climatology", help="Optional climatology file for climate checks")

    args = parser.parse_args()

    # Load data
    if args.input.endswith(".parquet"):
        df = pd.read_parquet(args.input)
    else:
        df = pd.read_csv(args.input, index_col=0, parse_dates=True)

    # Run QC
    qc = QualityController()
    if args.climatology:
        qc.load_climatology(args.climatology)

    df_clean, flags = qc.process(df)

    # Save
    if args.output.endswith(".parquet"):
        df_clean.to_parquet(args.output)
        flags.to_parquet(args.output.replace(".parquet", "_flags.parquet"))
    else:
        df_clean.to_csv(args.output)
        flags.to_csv(args.output.replace(".csv", "_flags.csv"))


if __name__ == "__main__":
    main()
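Like the pipeline module, this file is runnable directly; for example (file names illustrative):

    python data/processing/quality_control.py observations.parquet cleaned.parquet --climatology climatology.csv

which writes cleaned.parquet alongside a cleaned_flags.parquet holding the per-variable QC flags.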
data/raw/ghcn_daily/ghcnd-inventory.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c15c3a990f8e646d36e8dce7ef68de4c53f9226d1aa5917cd9e9a35ceb4e5f7
size 35313694
data/raw/ghcn_daily/ghcnd-stations.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8f320b9aa020b8ac7f456ed6af3c96194c7fa8536ddb6937226ef7767b5c8a1
size 11150588
data/raw/ghcn_daily/stations/USC00010063.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010148.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010160.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010163.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010178.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010252.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010260.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010267.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010369.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010377.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010390.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010395.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010402.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010407.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010422.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010425.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010430.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010505.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010583.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010616.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010655.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010757.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010764.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010823.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00010836.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00011069.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00011080.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00011084.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00011099.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00011189.dly
ADDED
The diff for this file is too large to render. See raw diff

data/raw/ghcn_daily/stations/USC00011288.dly
ADDED
The diff for this file is too large to render. See raw diff