predict-power / hrrr_fetch.py
jeffliulab's picture
Docs refresh + dynamic live MAPE in About tab + 3 new figures
95fc2f5 verified
"""
Real-time HRRR weather fetcher for the predict-power Space.
This is the runtime counterpart to ``scripts/data_preparation/fetch_hrrr_weather.py``
(used to build the training set). It MUST produce arrays in the same
shape, channel order, and grid as training, otherwise the model sees an
out-of-distribution input. Specifically:
- 7 channels in fixed order:
[TMP_2m, RH_2m, UGRD_10m, VGRD_10m, GUST_surface, DSWRF_surface, APCP_1hr]
- NE bbox: lat 40.5-47.5 N, lon -74.0 to -66.0 (West)
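- Regridded to 450 lat-rows x 449 lon-cols by linear interpolation
  (training used xarray.interp(method="linear"); at runtime a precomputed
  Delaunay/barycentric interpolator stands in for it), NOT direct slicing
  of the native Lambert-Conformal grid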
We fetch from the public ``noaa-hrrr-bdp-pds`` AWS S3 bucket via the
Herbie library (proven path; same as training).
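Public entry points:
- ``fetch_history(end_dt, hours=24)`` returns ``(hours, 450, 449, 7)``,
  one f00 analysis per valid hour in ``[end_dt - hours, end_dt - 1h]``
- ``fetch_forecast_for_window(target_start, hours=24)`` returns
  ``(weather, cycle_dt, fxx_start)`` where ``weather`` is
  ``(hours, 450, 449, 7)`` taken from the latest 00/06/12/18 UTC cycle at
  or before ``target_start - publication_lag_hours``
- ``fetch_forecast(cycle_dt, hours=24)`` returns ``(hours, 450, 449, 7)``,
  cycle_dt's f01..f{hours} forecast hours
- ``latest_available_cycle(target_dt)`` returns the newest cycle (within
  ``max_lookback_hours``) whose f01 file Herbie can locate
All fetches are cached as ``{cycle_YYYYMMDDHH}_f{NN}.npz`` under
``$HRRR_CACHE_DIR`` (default ``/tmp/hrrr_cache``).
The cache survives within an HF Space uptime session and is wiped on sleep.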
"""
from __future__ import annotations

import logging
import os
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Callable, Iterable, Optional, Sequence

import numpy as np
logger = logging.getLogger(__name__)

# === Match training pipeline EXACTLY ===
# Geographic bounding box of the New England target grid, in degrees.
# Longitudes are signed (negative = West).
_BBOX = dict(lat_min=40.5, lat_max=47.5, lon_min=-74.0, lon_max=-66.0)

GRID_H = 450  # number of latitude rows in the regridded output
GRID_W = 449  # number of longitude columns in the regridded output
N_CHANNELS = 7

# Regular (geographic, not native Lambert-Conformal) target axes. linspace
# includes both endpoints: _LAT runs lat_min..lat_max and _LON lon_min..lon_max.
_LAT = np.linspace(_BBOX["lat_min"], _BBOX["lat_max"], GRID_H)
_LON = np.linspace(_BBOX["lon_min"], _BBOX["lon_max"], GRID_W)

# Ordered channel table: (channel name, Herbie search regex). This order
# defines the last axis of every tensor the module returns.
_CHANNELS: list[tuple[str, str]] = [
    ("TMP", ":TMP:2 m above ground"),
    ("RH", ":RH:2 m above ground"),
    ("UGRD", ":UGRD:10 m above ground"),
    ("VGRD", ":VGRD:10 m above ground"),
    ("GUST", ":GUST:surface"),
    ("DSWRF", ":DSWRF:surface"),
    ("APCP_1hr", ":APCP:surface:0-1 hour acc"),
]

# On-disk cache location; override with the HRRR_CACHE_DIR environment
# variable. The directory is created eagerly at import time.
CACHE_DIR = Path(os.getenv("HRRR_CACHE_DIR", "/tmp/hrrr_cache"))
CACHE_DIR.mkdir(parents=True, exist_ok=True)
def _cache_path(cycle_dt: datetime, fxx: int) -> Path:
    """Return the cache file for one (cycle, forecast-hour) pair.

    The file lives directly in CACHE_DIR and is named from the cycle's
    ``YYYYMMDDHH`` stamp plus the zero-padded forecast hour.
    """
    file_name = f"{cycle_dt.strftime('%Y%m%d%H')}_f{fxx:02d}.npz"
    return CACHE_DIR.joinpath(file_name)
def _hour_floor_utc(dt: datetime) -> datetime:
    """Truncate ``dt`` to the start of its UTC hour, returned as a naive datetime.

    A naive input is interpreted as already being in UTC; an aware input is
    converted to UTC first. The result always has ``tzinfo=None``.
    """
    aware = dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
    in_utc = aware.astimezone(timezone.utc)
    return in_utc.replace(tzinfo=None, microsecond=0, second=0, minute=0)
# --- regridding weights (built lazily, then reused for the process) ---
# HRRR's native Lambert-Conformal grid is the same for every cycle, so the
# interpolation weights only need to be derived once, from whichever field
# is fetched first. The stored pack (see _build_regrid_weights) holds the
# source-cell mask, per-target vertex indices and barycentric weights, so a
# per-channel regrid is just a gather plus a weighted sum.
_REGRID_CACHE: dict = dict()
def _build_regrid_weights(lat2d: np.ndarray, lon2d_signed: np.ndarray) -> dict:
    """Precompute a linear interpolator from the native HRRR grid onto the
    regular (GRID_H, GRID_W) lat/lon target grid.

    The interpolator is meant to reproduce the xarray.interp(method="linear")
    regridding used to build the training set.

    How it works:
    - The cropped source points are triangulated with Delaunay in
      (lat, lon) degree space.
    - Each target cell becomes a barycentric combination of the three
      vertices of the triangle that contains it.
    - Target cells outside the triangulation's convex hull copy their
      nearest source point instead (weight 1 on that vertex, 0 on the
      other two).

    Args:
        lat2d: 2-D latitudes of the native grid, shaped like the fields that
            will be regridded (callers check for (1059, 1799)).
        lon2d_signed: 2-D longitudes with the same shape, in signed degrees
            (negative = West).

    Returns:
        dict with keys:
        - ``"mask"``: bool array shaped like ``lat2d``. True for source
          cells inside the target bbox widened by a 1.5 deg margin; only
          these cells take part in the interpolation.
        - ``"vertices_per_target"``: (GRID_H*GRID_W, 3) int32 indices into
          ``field2d[mask]``, the masked and flattened source values.
        - ``"weights"``: (GRID_H*GRID_W, 3) float32 weights; each row sums
          to 1.

    Raises:
        RuntimeError: if no source cell falls inside the widened bbox.
    """
    from scipy.spatial import Delaunay  # noqa: WPS433

    # Crop with a margin so the target grid's corners always have source
    # neighbors on every side, while keeping the triangulation small.
    margin = 1.5
    mask = ((lat2d >= _BBOX["lat_min"] - margin)
            & (lat2d <= _BBOX["lat_max"] + margin)
            & (lon2d_signed >= _BBOX["lon_min"] - margin)
            & (lon2d_signed <= _BBOX["lon_max"] + margin))
    if not mask.any():
        raise RuntimeError("Bounding-box mask is empty; HRRR grid mismatch?")

    src_pts = np.stack(
        [lat2d[mask].astype(np.float64),
         lon2d_signed[mask].astype(np.float64)],
        axis=-1)
    lat_t, lon_t = np.meshgrid(_LAT, _LON, indexing="ij")
    tgt_pts = np.stack([lat_t.ravel(), lon_t.ravel()], axis=-1)

    tri = Delaunay(src_pts)
    simplex = tri.find_simplex(tgt_pts)
    oob = simplex < 0  # target cells outside the source convex hull
    n_outside = int(oob.sum())
    if n_outside:
        logger.warning(
            " %d of %d target cells fall outside the source convex hull; "
            "filling those with nearest-neighbor", n_outside, simplex.size)
        # Point OOB cells at a valid dummy simplex so the vectorized math
        # below is well-defined; their vertices and weights are overwritten.
        simplex = np.where(oob, 0, simplex)

    # Barycentric coordinates from scipy's per-simplex affine transform:
    # transform[s, :2] is T^-1 and transform[s, 2] is the reference vertex r.
    # The first two coordinates are T^-1 (x - r); the third is 1 - b0 - b1.
    t_inv = tri.transform[simplex, :2]             # (N, 2, 2)
    offset = tgt_pts - tri.transform[simplex, 2]    # (N, 2)
    bary_in = np.einsum("ijk,ik->ij", t_inv, offset)  # (N, 2)
    weights = np.concatenate(
        [bary_in, 1.0 - bary_in.sum(axis=1, keepdims=True)], axis=1)  # (N, 3)
    verts_per_target = tri.simplices[simplex].astype(np.int32)        # (N, 3)

    if n_outside:
        from scipy.spatial import cKDTree  # noqa: WPS433
        _, nn_idx = cKDTree(src_pts).query(tgt_pts[oob], k=1)
        # Nearest-neighbor fallback: all weight on vertex 0, which is the
        # nearest source point; the other two vertices contribute nothing.
        verts_per_target[oob, 0] = nn_idx
        weights[oob] = (1.0, 0.0, 0.0)

    return {
        "mask": mask,
        "vertices_per_target": verts_per_target,  # (N, 3) int32
        "weights": weights.astype(np.float32),    # (N, 3)
    }
def _regrid(field2d: np.ndarray, weights_pack: dict) -> np.ndarray:
    """Interpolate one native-grid HRRR field onto the (GRID_H, GRID_W) grid.

    ``weights_pack`` is the dict built by ``_build_regrid_weights``:
    - the field is cropped with ``mask`` and cast to float32;
    - each target cell gathers its three source values through
      ``vertices_per_target``;
    - those values are combined using ``weights``.
    """
    source_vals = field2d[weights_pack["mask"]].astype(np.float32)
    neighbor_vals = source_vals[weights_pack["vertices_per_target"]]  # (N, 3)
    weighted = neighbor_vals * weights_pack["weights"]
    flat = weighted.sum(axis=1)
    return flat.reshape((GRID_H, GRID_W))
# Serializes the one-time regrid-weight build. _fetch_parallel calls the
# fetchers from several threads at once; without this lock every thread that
# reaches its first field would rebuild the triangulation concurrently.
_REGRID_LOCK = threading.Lock()


def _apcp_1hr_search(fxx: int) -> str:
    """Herbie search regex for the 1-hour APCP bucket ending at hour ``fxx`` (>= 1)."""
    # Accumulation windows are labelled by forecast hour: f01 carries
    # "0-1 hour acc", while f02+ carries the run total "0-N hour acc"
    # alongside the 1-hour bucket "(N-1)-N hour acc". Only the 1-hour bucket
    # matches the APCP_1hr channel. For fxx=1 this equals the _CHANNELS regex.
    return f":APCP:surface:{fxx - 1}-{fxx} hour acc"


def _regrid_weights_for(arr) -> dict:
    """Return the process-wide regrid weights, building them on first use
    from ``arr``'s ``latitude``/``longitude`` coordinates."""
    pack = _REGRID_CACHE.get("weights_pack")
    if pack is not None:
        return pack
    with _REGRID_LOCK:
        # Re-check: another thread may have built the pack while we waited.
        pack = _REGRID_CACHE.get("weights_pack")
        if pack is None:
            lat2d = arr.coords["latitude"].values
            lon2d = arr.coords["longitude"].values
            # Normalize 0..360 longitudes to signed degrees (negative = West).
            lon2d_signed = np.where(lon2d > 180, lon2d - 360, lon2d)
            pack = _build_regrid_weights(lat2d, lon2d_signed)
            _REGRID_CACHE["weights_pack"] = pack
            logger.info("Built HRRR -> NE-grid regrid weights "
                        "(one-time setup, ~0.3s)")
    return pack


def _fetch_one_via_herbie(cycle_dt: datetime, fxx: int) -> np.ndarray:
    """Download one (cycle, forecast-hour) pair and return it regridded.

    Returns a (GRID_H, GRID_W, N_CHANNELS) float32 array whose last axis
    follows ``_CHANNELS`` order. Always hits the network; caching is the
    caller's job.

    APCP_1hr uses the 1-hour accumulation bucket ending at ``fxx``. If that
    field cannot be read (e.g. f00 has no accumulation window), the channel
    is zero-filled instead of failing the whole hour.

    Raises:
        RuntimeError: herbie is not installed, a non-APCP channel could not
            be read, a field has an unexpected native shape, or the
            regridded tensor contains NaN.
    """
    try:
        from herbie import Herbie  # noqa: WPS433 (optional heavy dep)
    except ImportError as e:
        raise RuntimeError(
            f"hrrr_fetch.py requires herbie-data: {e}") from e

    H = Herbie(
        cycle_dt.strftime("%Y-%m-%d %H:00"),
        model="hrrr",
        product="sfc",
        fxx=fxx,
        verbose=False,
    )

    channels: list[np.ndarray] = []
    for ch_name, regex in _CHANNELS:
        if ch_name == "APCP_1hr" and fxx >= 1:
            regex = _apcp_1hr_search(fxx)
        try:
            # Newer Herbie (>=2024.x) renamed `searchString` to `search`
            ds = H.xarray(search=regex, verbose=False)
        except Exception as e:  # noqa: BLE001
            if ch_name == "APCP_1hr":
                # Zero-fill rather than fail: the training mean for this
                # channel is near zero in MM units, so post-z-score the
                # model sees ~0.
                logger.info("APCP_1hr unavailable at %s f%02d (%s); using zero",
                            cycle_dt, fxx,
                            type(e).__name__ if not str(e) else str(e)[:80])
                channels.append(np.zeros((GRID_H, GRID_W), dtype=np.float32))
                continue
            raise RuntimeError(
                f"Herbie xarray() failed for {ch_name} at "
                f"{cycle_dt.isoformat()} f{fxx:02d}: {e}") from e

        arr = ds[next(iter(ds.data_vars))]
        field2d = np.squeeze(arr.values)
        if field2d.shape != (1059, 1799):
            raise RuntimeError(
                f"unexpected HRRR field shape {field2d.shape} for {ch_name}")
        regridded = _regrid(field2d, _regrid_weights_for(arr))
        channels.append(regridded.astype(np.float32))

    tensor = np.stack(channels, axis=-1)
    if np.isnan(tensor).any():
        raise RuntimeError(
            f"NaN in regridded HRRR tensor for "
            f"{cycle_dt.isoformat()} f{fxx:02d}")
    return tensor
def _write_cache_atomic(path: Path, weather: np.ndarray) -> None:
    """Best-effort atomic write of ``weather`` to the .npz file ``path``.

    Writes to a temp file in the same directory, then renames it into place,
    so a concurrent reader never sees a half-written archive. OS-level
    failures (e.g. disk full) are logged and swallowed, because the cache is
    only an optimization.
    """
    tmp_name = None
    try:
        fd, tmp_name = tempfile.mkstemp(dir=path.parent,
                                        prefix=path.name + ".",
                                        suffix=".tmp")
        with os.fdopen(fd, "wb") as fh:
            # A file object (not a path) stops numpy from appending ".npz"
            # to the temporary name.
            np.savez_compressed(fh, weather=weather)
        os.replace(tmp_name, path)
    except OSError as e:
        logger.warning("Could not write HRRR cache file %s: %s", path, e)
        if tmp_name is not None:
            Path(tmp_name).unlink(missing_ok=True)


def _fetch_with_cache(cycle_dt: datetime, fxx: int) -> np.ndarray:
    """Return one (cycle, fxx) tensor, served from the on-disk cache when possible.

    On a miss the tensor is fetched from the network and stored at
    ``_cache_path(cycle_dt, fxx)`` as float16 (halves disk usage,
    ~2.8 MB/file vs 5.6 MB). The returned array is always float32 and
    carries exactly the stored values, whether it came from a hit or a miss.
    Repeated calls for the same key therefore return identical data.

    Cache problems never fail the call: an unreadable cache file is deleted
    and refetched, and a failed cache write is logged and skipped. Errors
    from the fetch itself propagate (RuntimeError).
    """
    p = _cache_path(cycle_dt, fxx)
    if p.exists():
        try:
            with np.load(p) as f:
                return f["weather"].astype(np.float32)
        except Exception:  # noqa: BLE001 - corrupt/truncated file: refetch
            logger.warning("Discarding unreadable HRRR cache file %s", p)
            p.unlink(missing_ok=True)

    # Quantize once, up front. Otherwise the first call would return full
    # precision and every later cache hit a float16-rounded copy.
    stored = _fetch_one_via_herbie(cycle_dt, fxx).astype(np.float16)
    _write_cache_atomic(p, stored)
    return stored.astype(np.float32)
def _fetch_parallel(jobs: Sequence[tuple[datetime, int]],
                    parallel: int = 8,
                    progress: Optional[Callable[[int, int, str], None]] = None,
                    ) -> dict[tuple[datetime, int], np.ndarray]:
    """Fetch many ``(cycle_dt, fxx)`` pairs and return ``{(cycle_dt, fxx): tensor}``.

    Every job goes through ``_fetch_with_cache``. With ``parallel <= 1`` the
    jobs run sequentially in the given order. Otherwise they run on a thread
    pool of ``parallel`` workers and finish in arbitrary order.

    ``progress``, if given, is called after each finished job as
    ``progress(n_done, n_total, "YYYY-MM-DD HH fNN")``.

    The first failing job's exception propagates. In the threaded path, jobs
    that have not started yet are cancelled first, so the caller is not kept
    waiting on downloads whose results would be thrown away.
    """
    if not jobs:
        return {}
    n_total = len(jobs)

    def _report(n_done: int, cdt: datetime, fxx: int) -> None:
        if progress:
            progress(n_done, n_total,
                     f"{cdt.strftime('%Y-%m-%d %H')} f{fxx:02d}")

    out: dict[tuple[datetime, int], np.ndarray] = {}
    if parallel <= 1:
        for n_done, (cdt, fxx) in enumerate(jobs, start=1):
            out[(cdt, fxx)] = _fetch_with_cache(cdt, fxx)
            _report(n_done, cdt, fxx)
        return out

    with ThreadPoolExecutor(max_workers=parallel) as ex:
        futures = {ex.submit(_fetch_with_cache, cdt, fxx): (cdt, fxx)
                   for cdt, fxx in jobs}
        try:
            for n_done, fut in enumerate(as_completed(futures), start=1):
                cdt, fxx = futures[fut]
                out[(cdt, fxx)] = fut.result()
                _report(n_done, cdt, fxx)
        except BaseException:
            # Fail fast: drop still-queued jobs before the executor's
            # shutdown waits on them. Jobs already running cannot be
            # interrupted and run to completion.
            for pending in futures:
                pending.cancel()
            raise
    return out
# =====================================================================
# Public API
# =====================================================================
def fetch_history(end_dt: datetime, hours: int = 24,
                  parallel: int = 8,
                  progress: Optional[Callable[[int, int, str], None]] = None,
                  ) -> np.ndarray:
    """Stack the HRRR f00 analyses for the ``hours`` hours before ``end_dt``.

    ``end_dt`` is floored to the UTC hour (naive datetimes count as UTC).
    The valid hours run from ``end_dt - hours`` up to and including
    ``end_dt - 1h``, oldest first. Each one is the analysis (fxx=0) of the
    cycle initialised at that same hour, matching how the training data
    was constructed.

    Returns:
        ``(hours, 450, 449, 7)`` array stacked from the per-hour tensors.
    """
    anchor = _hour_floor_utc(end_dt)
    # Offsets count down from `hours` to 1, giving chronological order.
    jobs = [(anchor - timedelta(hours=back), 0)
            for back in range(hours, 0, -1)]
    results = _fetch_parallel(jobs, parallel=parallel, progress=progress)
    return np.stack([results[job] for job in jobs], axis=0)
# HRRR cycles with extended (0-48 h) forecasts: every 6 hours starting at
# 00 UTC. The other hourly cycles (01/02/04/05/...) only go out to f18, so
# they cannot cover a 24 h window.
LONG_CYCLE_HOURS = tuple(range(0, 24, 6))
def _latest_long_cycle_le(dt: datetime) -> datetime:
    """Return the most recent long cycle (hour in LONG_CYCLE_HOURS) at or before ``dt``.

    ``dt`` is floored to the UTC hour first (naive = UTC); the result is a
    naive UTC datetime.
    """
    start = _hour_floor_utc(dt)
    hours_back = 0
    while (start - timedelta(hours=hours_back)).hour not in LONG_CYCLE_HOURS:
        hours_back += 1
    return start - timedelta(hours=hours_back)
def fetch_forecast_for_window(target_start: datetime, hours: int = 24,
                              publication_lag_hours: int = 2,
                              parallel: int = 8,
                              progress: Optional[Callable[[int, int, str], None]] = None,
                              ) -> tuple[np.ndarray, datetime, int]:
    """Fetch forecasts covering valid hours ``target_start .. target_start + hours - 1``.

    ``target_start`` is floored to the UTC hour. The source cycle is the
    latest long cycle (00/06/12/18 UTC) at or before
    ``target_start - publication_lag_hours``. The lag is a margin for cycle
    processing delay; availability on S3 is not probed here. Forecast hours
    ``fxx_start .. fxx_start + hours - 1`` of that cycle are stacked, where
    ``fxx_start`` is the lead time of ``target_start`` in whole hours.

    Returns:
        ``(weather, cycle_dt, fxx_start)``. ``weather`` has shape
        ``(hours, 450, 449, 7)``; the other two let the caller log which
        cycle was used.
    """
    target_start = _hour_floor_utc(target_start)
    cycle_dt = _latest_long_cycle_le(
        target_start - timedelta(hours=publication_lag_hours))
    lead = target_start - cycle_dt
    fxx_start = int(lead.total_seconds() / 3600)
    fxx_range = range(fxx_start, fxx_start + hours)

    fetched = _fetch_parallel([(cycle_dt, fxx) for fxx in fxx_range],
                              parallel=parallel, progress=progress)
    weather = np.stack([fetched[(cycle_dt, fxx)] for fxx in fxx_range], axis=0)
    return weather, cycle_dt, fxx_start
def fetch_forecast(cycle_dt: datetime, hours: int = 24,
                   parallel: int = 8,
                   progress: Optional[Callable[[int, int, str], None]] = None,
                   ) -> np.ndarray:
    """Backwards-compat wrapper: stack f01..f{hours} of one specific cycle.

    ``cycle_dt`` is floored to the UTC hour. Only the long cycles
    (00/06/12/18 UTC) reliably reach 24+ hours. For automatic cycle
    selection, prefer ``fetch_forecast_for_window``.

    Returns:
        ``(hours, 450, 449, 7)`` array, one slice per forecast hour.
    """
    cycle = _hour_floor_utc(cycle_dt)
    lead_hours = list(range(1, hours + 1))
    fetched = _fetch_parallel([(cycle, lead) for lead in lead_hours],
                              parallel=parallel, progress=progress)
    return np.stack([fetched[(cycle, lead)] for lead in lead_hours], axis=0)
def latest_available_cycle(target_dt: datetime,
                           max_lookback_hours: int = 4,
                           ) -> datetime:
    """Find the most recent HRRR cycle <= ``target_dt`` that appears to be on S3.

    HRRR has a ~1-2 hour publication lag. Cycles are probed from
    ``target_dt`` backwards, one hour at a time, for up to
    ``max_lookback_hours`` hours. A cycle counts as available when Herbie
    resolves a GRIB location (``H.grib``) for its f01 file. Probe errors are
    treated as "not available" and logged at debug level.

    Returns:
        The cycle datetime (UTC, hour-floored, naive).

    Raises:
        RuntimeError: herbie is not installed, or no cycle in the lookback
            window resolved. The last probe error, if any, is chained as
            the cause.
    """
    target_dt = _hour_floor_utc(target_dt)
    try:
        from herbie import Herbie  # noqa: WPS433
    except ImportError as e:
        raise RuntimeError(f"herbie-data not installed: {e}") from e

    last_error: Optional[Exception] = None
    for back in range(max_lookback_hours + 1):
        cdt = target_dt - timedelta(hours=back)
        try:
            H = Herbie(cdt.strftime("%Y-%m-%d %H:00"),
                       model="hrrr", product="sfc", fxx=1, verbose=False)
            grib = H.grib
        except Exception as e:  # noqa: BLE001 - any probe failure = unavailable
            logger.debug("HRRR availability probe failed for %s: %s", cdt, e)
            last_error = e
            continue
        if grib is not None:
            return cdt
    raise RuntimeError(
        f"No HRRR cycle available within last {max_lookback_hours}h of "
        f"{target_dt.isoformat()}") from last_error
if __name__ == "__main__":
    # Smoke test, in three steps:
    # 1. fetch the f00 analysis and the f01 forecast of yesterday's 12 UTC
    #    cycle (through the cache);
    # 2. print shape, dtype and per-channel means for each;
    # 3. report how many .npz files the cache directory now holds.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    a_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
    yesterday_noon = a_day_ago.replace(hour=12, minute=0, second=0,
                                       microsecond=0, tzinfo=None)
    print(f"Smoke test cycle: {yesterday_noon} UTC")
    for lead in (0, 1):
        weather = _fetch_with_cache(yesterday_noon, lead)
        channel_means = ", ".join(
            f"{name}={weather[..., i].mean():.2f}"
            for i, (name, _) in enumerate(_CHANNELS))
        print(f" f{lead:02d}: shape={weather.shape}, dtype={weather.dtype}, "
              f"mean per channel: " + channel_means)
    n_files = len(list(CACHE_DIR.glob('*.npz')))
    print(f" cache dir: {CACHE_DIR}, n files: {n_files}")