Spaces:
Sleeping
Sleeping
| """ | |
| Real-time HRRR weather fetcher for the predict-power Space. | |
| This is the runtime counterpart to ``scripts/data_preparation/fetch_hrrr_weather.py`` | |
| (used to build the training set). It MUST produce arrays in the same | |
| shape, channel order, and grid as training, otherwise the model sees an | |
| out-of-distribution input. Specifically: | |
| - 7 channels in fixed order: | |
| [TMP_2m, RH_2m, UGRD_10m, VGRD_10m, GUST_surface, DSWRF_surface, APCP_1hr] | |
| - NE bbox: lat 40.5-47.5 N, lon -74.0 to -66.0 (West) | |
| - Regridded to 450 lat-rows x 449 lon-cols via xarray.interp(linear), | |
| NOT direct slicing of the native Lambert-Conformal grid | |
| We fetch from the public ``noaa-hrrr-bdp-pds`` AWS S3 bucket via the | |
| Herbie library (proven path; same as training). | |
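| Top-level entry points: | |
| - ``fetch_history(end_dt, hours=24)`` returns ``(hours, 450, 449, 7)``, | |
| one f00 analysis per requested hour | |
| - ``fetch_forecast_for_window(target_start, hours=24)`` returns | |
| ``(weather, cycle_dt, fxx_start)``; it picks the most recent 00/06/12/18 UTC | |
| cycle itself and is the preferred forecast entry point | |
| - ``fetch_forecast(cycle_dt, hours=24)`` returns ``(hours, 450, 449, 7)``, | |
| cycle_dt's f01..f{hours} forecast hours | |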
| Every fetched (cycle, forecast-hour) pair is cached at | |
| ``$HRRR_CACHE_DIR/{cycle_YYYYMMDDHH}_f{NN}.npz``; ``HRRR_CACHE_DIR`` defaults | |
| to ``/tmp/hrrr_cache``. | |
| The cache survives within an HF Space uptime session and is wiped on sleep. | |
| """ | |
from __future__ import annotations

import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Callable, Iterable, Optional, Sequence

import numpy as np
logger = logging.getLogger(__name__)

# === These values must match the training pipeline EXACTLY ===
# Geographic bounding box of the target grid, in signed degrees.
_BBOX = dict(lat_min=40.5, lat_max=47.5, lon_min=-74.0, lon_max=-66.0)

GRID_H = 450  # number of latitude rows in the target grid
GRID_W = 449  # number of longitude columns in the target grid
N_CHANNELS = 7

# Regular lat/lon target axes (geographic coordinates, not the native HRRR
# Lambert-Conformal grid); both endpoints of the bounding box are included.
_LAT = np.linspace(_BBOX["lat_min"], _BBOX["lat_max"], num=GRID_H)
_LON = np.linspace(_BBOX["lon_min"], _BBOX["lon_max"], num=GRID_W)
# Output channels as (channel name, Herbie search pattern) pairs.
# The list order IS the channel order of the last tensor axis, so entries
# must not be reordered.
_CHANNELS: list[tuple[str, str]] = [
    # fields at 2 m above ground
    ("TMP", ":TMP:2 m above ground"),
    ("RH", ":RH:2 m above ground"),
    # wind components at 10 m above ground
    ("UGRD", ":UGRD:10 m above ground"),
    ("VGRD", ":VGRD:10 m above ground"),
    # surface fields
    ("GUST", ":GUST:surface"),
    ("DSWRF", ":DSWRF:surface"),
    # 0-1 hour accumulation only; other accumulation windows do not match
    ("APCP_1hr", ":APCP:surface:0-1 hour acc"),
]
# On-disk cache location; override with the HRRR_CACHE_DIR environment
# variable. The directory is created as an import-time side effect.
CACHE_DIR = Path(os.getenv("HRRR_CACHE_DIR", "/tmp/hrrr_cache"))
CACHE_DIR.mkdir(exist_ok=True, parents=True)
def _cache_path(cycle_dt: datetime, fxx: int) -> Path:
    """Return the cache file path for one (cycle, forecast-hour) pair.

    The file name is ``{YYYYMMDDHH}_f{NN}.npz`` inside ``CACHE_DIR``, where
    the timestamp comes from ``cycle_dt`` and ``NN`` is ``fxx`` zero-padded
    to at least two digits.
    """
    cycle_stamp = cycle_dt.strftime("%Y%m%d%H")
    file_name = f"{cycle_stamp}_f{fxx:02d}.npz"
    return CACHE_DIR / file_name
def _hour_floor_utc(dt: datetime) -> datetime:
    """Convert ``dt`` to UTC, truncate it to the hour, and drop the tzinfo.

    A naive ``dt`` is taken to already be in UTC; an aware ``dt`` is
    converted to UTC first. The result is always a naive datetime.
    """
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    in_utc = aware.astimezone(timezone.utc)
    return in_utc.replace(minute=0, second=0, microsecond=0, tzinfo=None)
# --- regridding weights (built lazily, then reused for the whole process) ---
# HRRR's native Lambert-Conformal grid is fixed across cycles, so the regrid
# pack returned by ``_build_regrid_weights`` (crop mask, per-target vertex
# indices, barycentric weights) can be computed once from any sample dataset.
# It is stored here under the key "weights_pack"; regridding a channel is
# then a gather followed by a weighted sum.
_REGRID_CACHE: dict = dict()
def _build_regrid_weights(lat2d: np.ndarray, lon2d_signed: np.ndarray):
    """Precompute linear-interpolation weights from the HRRR grid to the
    regular ``_LAT`` x ``_LON`` target grid.

    The source points are the HRRR cells inside the target bounding box
    widened by 1.5 degrees on every side. They are Delaunay-triangulated in
    (lat, lon) space, and each target cell gets the three vertices of its
    enclosing triangle together with its barycentric weights. This is
    intended to reproduce the ``xarray.interp(method="linear")`` regridding
    used for the training data.

    Target cells outside the triangulation's convex hull fall back to
    nearest-neighbour: the nearest source point gets weight 1, the other
    two slots get weight 0.

    Args:
        lat2d: 2-D latitudes of the native HRRR grid.
        lon2d_signed: 2-D longitudes of the same grid, in signed degrees
            (west negative).

    Returns:
        dict with keys

        - ``mask``: bool array shaped like ``lat2d`` selecting the cropped
          source cells
        - ``vertices_per_target``: ``(GRID_H * GRID_W, 3)`` int32 indices
          into the masked (flattened) source values
        - ``weights``: ``(GRID_H * GRID_W, 3)`` float32 weights matching
          ``vertices_per_target``

        Target cells are ordered row-major (latitude-major).

    Raises:
        RuntimeError: if no source cell falls inside the widened box.
    """
    from scipy.spatial import Delaunay  # noqa: WPS433

    # Crop with margin so target-grid corners always have neighbors in source
    margin = 1.5
    lat_ok = ((lat2d >= _BBOX["lat_min"] - margin)
              & (lat2d <= _BBOX["lat_max"] + margin))
    lon_ok = ((lon2d_signed >= _BBOX["lon_min"] - margin)
              & (lon2d_signed <= _BBOX["lon_max"] + margin))
    mask = lat_ok & lon_ok
    if mask.sum() == 0:
        raise RuntimeError("Bounding-box mask is empty; HRRR grid mismatch?")

    src_pts = np.stack(
        [lat2d[mask].astype(np.float64),
         lon2d_signed[mask].astype(np.float64)],
        axis=-1)
    lat_grid, lon_grid = np.meshgrid(_LAT, _LON, indexing="ij")
    tgt_pts = np.stack([lat_grid.ravel(), lon_grid.ravel()], axis=-1)

    tri = Delaunay(src_pts)
    simplex = tri.find_simplex(tgt_pts)

    # find_simplex reports -1 for points outside the convex hull. Record
    # them now, because `simplex` is patched in place further down.
    outside = simplex < 0
    has_outside = bool(outside.any())
    if has_outside:
        logger.warning(
            " %d of %d target cells fall outside the source convex hull; "
            "filling those with nearest-neighbor",
            int(outside.sum()), simplex.size)

    # Barycentric weights of every target point inside its simplex
    # (vectorized via tri.transform). Rows for out-of-hull points are
    # meaningless here and are overwritten below.
    affine = tri.transform[simplex, :2]              # (N, 2, 2)
    offset = tgt_pts - tri.transform[simplex, 2]     # (N, 2)
    partial = np.einsum("ijk,ik->ij", affine, offset)  # (N, 2)
    weights = np.concatenate(
        [partial, 1 - partial.sum(axis=1, keepdims=True)], axis=1)  # (N, 3)

    if has_outside:
        from scipy.spatial import cKDTree  # noqa: WPS433

        _, nearest = cKDTree(src_pts).query(tgt_pts[outside], k=1)
        # Point out-of-hull cells at a dummy (valid) simplex so the gather
        # below works, and put all the weight on the first vertex slot.
        simplex[outside] = 0
        weights[outside] = [1.0, 0.0, 0.0]

    # Three source-vertex indices per target cell; fancy indexing copies.
    corner_idx = tri.simplices.astype(np.int32)[simplex]  # (N, 3)
    if has_outside:
        # First slot becomes the nearest neighbour; the other two slots
        # carry zero weight, so their indices do not matter.
        corner_idx[outside, 0] = nearest

    return {
        "mask": mask,
        "vertices_per_target": corner_idx.astype(np.int32),  # (N, 3)
        "weights": weights.astype(np.float32),                # (N, 3)
    }
def _regrid(field2d: np.ndarray, weights_pack: dict) -> np.ndarray:
    """Interpolate one native HRRR field onto the regular lat/lon grid.

    ``weights_pack`` is the dict produced by ``_build_regrid_weights``.
    Returns a ``(GRID_H, GRID_W)`` float32 array.
    """
    # Keep only the cropped source cells, flattened to 1-D.
    source = field2d[weights_pack["mask"]].astype(np.float32)
    # Gather the three vertex values of every target cell ...
    corners = source[weights_pack["vertices_per_target"]]
    # ... and blend them with that cell's barycentric weights.
    weighted = corners * weights_pack["weights"]
    return weighted.sum(axis=1).reshape(GRID_H, GRID_W)
# Guards the one-time construction of the regrid weights. ``_fetch_parallel``
# runs this fetcher on several threads at once; without the lock every worker
# that starts before the first build finishes would build its own copy.
_REGRID_LOCK = threading.Lock()


def _get_regrid_weights(arr) -> dict:
    """Return the process-wide regrid weights, building them on first use.

    ``arr`` is a data variable of a Herbie-produced dataset; only its
    ``latitude`` and ``longitude`` coordinates are read, and only when the
    weights have not been built yet.
    """
    pack = _REGRID_CACHE.get("weights_pack")
    if pack is not None:
        return pack
    with _REGRID_LOCK:
        # Re-check under the lock: another thread may have finished the
        # build while this one was waiting.
        pack = _REGRID_CACHE.get("weights_pack")
        if pack is None:
            lat2d = arr.coords["latitude"].values
            lon2d = arr.coords["longitude"].values
            # Longitudes above 180 are shifted into the signed range
            # used by the bounding box.
            lon2d_signed = np.where(lon2d > 180, lon2d - 360, lon2d)
            pack = _build_regrid_weights(lat2d, lon2d_signed)
            _REGRID_CACHE["weights_pack"] = pack
            logger.info("Built HRRR -> NE-grid regrid weights "
                        "(one-time setup, ~0.3s)")
    return pack


def _fetch_one_via_herbie(cycle_dt: datetime, fxx: int) -> np.ndarray:
    """Fetch one (cycle, forecast-hour) pair, return (450, 449, 7) float32.

    Caller is responsible for caching; this function always goes through
    Herbie. Channels are stacked on the last axis in ``_CHANNELS`` order.

    The APCP channel is zero-filled whenever its fetch raises (see the
    comment in the loop); every other channel failure is fatal.

    Raises:
        RuntimeError: if ``herbie`` is not installed, a non-APCP channel
            cannot be read, a field does not have the expected native
            shape, or the regridded tensor contains NaN. Exceptions raised
            outside those checks (e.g. by the ``Herbie`` constructor) are
            not wrapped and propagate unchanged.
    """
    try:
        from herbie import Herbie  # noqa: WPS433 (optional heavy dep)
    except ImportError as e:
        raise RuntimeError(
            f"hrrr_fetch.py requires herbie-data: {e}") from e

    H = Herbie(
        cycle_dt.strftime("%Y-%m-%d %H:00"),
        model="hrrr",
        product="sfc",
        fxx=fxx,
        verbose=False,
    )

    channels: list[np.ndarray] = []
    for ch_name, regex in _CHANNELS:
        try:
            # Newer Herbie (>=2024.x) renamed `searchString` to `search`
            ds = H.xarray(search=regex, verbose=False)
        except Exception as e:  # noqa: BLE001
            # APCP accumulation window varies with forecast hour:
            # f00 has no APCP, f01 has "0-1 hour acc" (matches our regex),
            # f02 has "0-2 hour acc" or "1-2 hour acc", etc. We zero-fill
            # any APCP fetch failure (the training mean is near zero in
            # MM units anyway, so post-z-score the model sees ~0).
            if ch_name == "APCP_1hr":
                logger.info("APCP_1hr unavailable at %s f%02d (%s); using zero",
                            cycle_dt, fxx,
                            type(e).__name__ if not str(e) else str(e)[:80])
                channels.append(np.zeros((GRID_H, GRID_W), dtype=np.float32))
                continue
            raise RuntimeError(
                f"Herbie xarray() failed for {ch_name} at "
                f"{cycle_dt.isoformat()} f{fxx:02d}: {e}") from e

        # The search pattern is expected to select a single variable.
        var = next(iter(ds.data_vars))
        arr = ds[var]
        field2d = np.squeeze(arr.values)
        if field2d.shape != (1059, 1799):
            raise RuntimeError(
                f"unexpected HRRR field shape {field2d.shape} for {ch_name}")

        # Weights are built once per process, from the first dataset seen.
        weights_pack = _get_regrid_weights(arr)
        channels.append(_regrid(field2d, weights_pack).astype(np.float32))

    tensor = np.stack(channels, axis=-1)
    if np.isnan(tensor).any():
        raise RuntimeError(
            f"NaN in regridded HRRR tensor for "
            f"{cycle_dt.isoformat()} f{fxx:02d}")
    return tensor
def _read_cached_tensor(path: Path) -> Optional[np.ndarray]:
    """Load a cached tensor as float32, or return None if it is unusable.

    A missing file returns None silently. A file that cannot be read, or
    whose array has an unexpected shape, is logged, deleted and reported
    as None so the caller refetches.
    """
    try:
        with np.load(path) as f:
            cached = f["weather"].astype(np.float32)
    except FileNotFoundError:
        return None
    except Exception as e:  # noqa: BLE001 - any unreadable file => refetch
        logger.warning("Discarding unreadable HRRR cache file %s (%s)",
                       path, e)
        path.unlink(missing_ok=True)
        return None
    expected = (GRID_H, GRID_W, N_CHANNELS)
    if cached.shape != expected:
        logger.warning("Discarding HRRR cache file %s: shape %s, expected %s",
                       path, cached.shape, expected)
        path.unlink(missing_ok=True)
        return None
    return cached


def _write_cached_tensor(path: Path, tensor: np.ndarray) -> None:
    """Best-effort, atomic write of ``tensor`` (as float16) to ``path``.

    The archive is written to a temporary file in the same directory and
    then moved into place with ``os.replace``, so a concurrent reader never
    opens a half-written file. An ``OSError`` is logged, not raised: a
    cache-write failure must not discard a successfully fetched tensor.
    """
    import tempfile  # noqa: WPS433 (only needed on the cache-write path)

    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(
            dir=path.parent, prefix=f"{path.stem}.", suffix=".tmp")
        try:
            # Passing an open file object stops numpy from appending its
            # own ".npz" suffix to the temporary name.
            with os.fdopen(fd, "wb") as fh:
                # Store as float16 to halve disk usage
                # (~2.8 MB/file vs 5.6 MB)
                np.savez_compressed(fh, weather=tensor.astype(np.float16))
            os.replace(tmp_name, path)
        except BaseException:
            Path(tmp_name).unlink(missing_ok=True)
            raise
    except OSError as e:
        logger.warning("Could not write HRRR cache file %s (%s)", path, e)


def _fetch_with_cache(cycle_dt: datetime, fxx: int) -> np.ndarray:
    """Fetch one (cycle, fxx) pair via cache or network.

    Returns a ``(GRID_H, GRID_W, N_CHANNELS)`` float32 tensor. On a cache
    miss the tensor is fetched with ``_fetch_one_via_herbie`` and then
    written to the cache on a best-effort basis.

    NOTE: the cache stores float16, so a cache hit returns values rounded
    to float16 precision, while the call that populated the cache returns
    the unrounded float32 tensor.

    Raises:
        RuntimeError: propagated from ``_fetch_one_via_herbie``.
    """
    p = _cache_path(cycle_dt, fxx)
    cached = _read_cached_tensor(p)
    if cached is not None:
        return cached
    tensor = _fetch_one_via_herbie(cycle_dt, fxx)
    _write_cached_tensor(p, tensor)
    return tensor
def _fetch_parallel(jobs: Sequence[tuple[datetime, int]],
                    parallel: int = 8,
                    progress: Optional[Callable[[int, int, str], None]] = None,
                    ) -> dict[tuple[datetime, int], np.ndarray]:
    """Fetch many (cycle_dt, fxx) pairs and return them keyed by job.

    Args:
        jobs: ``(cycle_dt, fxx)`` pairs to fetch via ``_fetch_with_cache``.
        parallel: worker-thread count; ``<= 1`` fetches sequentially in
            the calling thread, in job order.
        progress: optional callback invoked after each finished job as
            ``progress(n_done, n_total, label)``.

    Returns:
        dict mapping each ``(cycle_dt, fxx)`` tuple to its tensor; empty
        when ``jobs`` is empty.

    The first fetch that raises propagates its exception to the caller.
    """
    results: dict[tuple[datetime, int], np.ndarray] = {}
    if not jobs:
        return results

    total = len(jobs)

    def _report(done: int, cycle: datetime, fhour: int) -> None:
        # Forward one completion to the optional progress callback.
        if progress:
            progress(done, total,
                     f"{cycle.strftime('%Y-%m-%d %H')} f{fhour:02d}")

    if parallel <= 1:
        for done, (cycle, fhour) in enumerate(jobs, start=1):
            results[(cycle, fhour)] = _fetch_with_cache(cycle, fhour)
            _report(done, cycle, fhour)
        return results

    with ThreadPoolExecutor(max_workers=parallel) as pool:
        pending = {}
        for cycle, fhour in jobs:
            future = pool.submit(_fetch_with_cache, cycle, fhour)
            pending[future] = (cycle, fhour)
        # Results are collected in completion order, not submission order.
        for done, future in enumerate(as_completed(pending), start=1):
            cycle, fhour = pending[future]
            results[(cycle, fhour)] = future.result()
            _report(done, cycle, fhour)
    return results
| # ===================================================================== | |
| # Public API | |
| # ===================================================================== | |
def fetch_history(end_dt: datetime, hours: int = 24,
                  parallel: int = 8,
                  progress: Optional[Callable[[int, int, str], None]] = None,
                  ) -> np.ndarray:
    """Return ``(hours, 450, 449, 7)`` float32 of HRRR f00 analyses for
    the inclusive window ``[end_dt - hours, end_dt - 1h]``.

    ``end_dt`` is first floored to the hour in UTC. Each valid hour ``H``
    in the window is fetched as cycle ``H`` with fxx=0 (the analysis at
    that hour), matching how the training data was constructed. The first
    axis is ordered oldest to newest.

    ``parallel`` and ``progress`` are passed through to ``_fetch_parallel``.
    """
    window_end = _hour_floor_utc(end_dt)
    # Count back from `hours` down to 1 so the oldest hour comes first.
    valid_hours = [window_end - timedelta(hours=back)
                   for back in range(hours, 0, -1)]
    fetched = _fetch_parallel([(valid_hour, 0) for valid_hour in valid_hours],
                              parallel=parallel, progress=progress)
    analyses = [fetched[(valid_hour, 0)] for valid_hour in valid_hours]
    return np.stack(analyses, axis=0)
# UTC start hours of the HRRR cycles that carry extended (0-48 h) forecasts.
# The other hourly cycles (01/02/04/05/...) only go out to f18, so a 24 h
# window cannot be served from them.
LONG_CYCLE_HOURS = tuple(range(0, 24, 6))  # (0, 6, 12, 18)
def _latest_long_cycle_le(dt: datetime) -> datetime:
    """Return the most recent HRRR long cycle (00/06/12/18 UTC) <= dt.

    ``dt`` is floored to the hour in UTC first; the result is a naive UTC
    datetime whose hour is one of ``LONG_CYCLE_HOURS``.
    """
    candidate = _hour_floor_utc(dt)
    one_hour = timedelta(hours=1)
    # Step back hour by hour until a long-cycle hour is reached.
    while True:
        if candidate.hour in LONG_CYCLE_HOURS:
            return candidate
        candidate -= one_hour
def fetch_forecast_for_window(target_start: datetime, hours: int = 24,
                              publication_lag_hours: int = 2,
                              parallel: int = 8,
                              progress: Optional[Callable[[int, int, str], None]] = None,
                              ) -> tuple[np.ndarray, datetime, int]:
    """Return forecasts covering valid hours
    ``[target_start, target_start + hours - 1]`` from one long HRRR cycle.

    The cycle used is the most recent long cycle (00/06/12/18 UTC) at or
    before ``target_start - publication_lag_hours``; the lag is a margin
    for cycle processing delay. Availability of the cycle is not probed
    here.

    Args:
        target_start: first valid hour wanted (floored to the hour, UTC).
        hours: number of consecutive valid hours to return.
        publication_lag_hours: how far before ``target_start`` a cycle
            must have started to be considered published.
        parallel, progress: passed through to ``_fetch_parallel``.

    Returns:
        ``(weather, cycle_dt, fxx_start)`` where ``weather`` has shape
        ``(hours, 450, 449, 7)``, ``cycle_dt`` is the cycle that was used
        and ``fxx_start`` is the forecast hour of the first slice, so the
        caller can log which cycle was used.
    """
    first_valid = _hour_floor_utc(target_start)
    cycle_dt = _latest_long_cycle_le(
        first_valid - timedelta(hours=publication_lag_hours))
    # Lead time, in whole hours, from the cycle start to the first valid hour.
    fxx_start = int((first_valid - cycle_dt).total_seconds() / 3600)
    lead_hours = range(fxx_start, fxx_start + hours)

    fetched = _fetch_parallel([(cycle_dt, fxx) for fxx in lead_hours],
                              parallel=parallel, progress=progress)
    weather = np.stack([fetched[(cycle_dt, fxx)] for fxx in lead_hours],
                       axis=0)
    return weather, cycle_dt, fxx_start
def fetch_forecast(cycle_dt: datetime, hours: int = 24,
                   parallel: int = 8,
                   progress: Optional[Callable[[int, int, str], None]] = None,
                   ) -> np.ndarray:
    """Backwards-compat wrapper: fetch f01..f{hours} from a specific cycle.

    ``cycle_dt`` is floored to the hour in UTC. Returns
    ``(hours, 450, 449, 7)`` with the first axis ordered by forecast hour.

    NOTE: only long cycles (00/06/12/18 UTC) reliably cover 24+ hours.
    For automatic cycle selection, prefer ``fetch_forecast_for_window``.
    """
    cycle = _hour_floor_utc(cycle_dt)
    lead_hours = range(1, hours + 1)
    fetched = _fetch_parallel([(cycle, fxx) for fxx in lead_hours],
                              parallel=parallel, progress=progress)
    slices = [fetched[(cycle, fxx)] for fxx in lead_hours]
    return np.stack(slices, axis=0)
def latest_available_cycle(target_dt: datetime,
                           max_lookback_hours: int = 4,
                           ) -> datetime:
    """Find the most recent HRRR cycle <= ``target_dt`` whose forecast
    hours appear to be on S3 (HRRR has ~1-2 hour publication lag).

    Each hourly cycle from ``target_dt`` backwards (up to and including
    ``max_lookback_hours`` hours back) is probed by instantiating Herbie
    for its f01 file; the first cycle whose ``H.grib`` resolves wins.

    Args:
        target_dt: latest acceptable cycle time (floored to the hour, UTC).
        max_lookback_hours: how many hours before ``target_dt`` to probe.

    Returns:
        The cycle datetime (UTC, hour-floored, naive).

    Raises:
        RuntimeError: if ``herbie`` is not installed, or if no probed
            cycle resolves. In the latter case the last probe exception,
            if any, is attached as ``__cause__`` so a network or
            credentials problem is distinguishable from "not published
            yet".
    """
    target_dt = _hour_floor_utc(target_dt)
    try:
        from herbie import Herbie  # noqa: WPS433
    except ImportError as e:
        raise RuntimeError(f"herbie-data not installed: {e}") from e

    last_error: Optional[Exception] = None
    for back in range(max_lookback_hours + 1):
        cdt = target_dt - timedelta(hours=back)
        try:
            H = Herbie(cdt.strftime("%Y-%m-%d %H:00"),
                       model="hrrr", product="sfc", fxx=1, verbose=False)
            found = H.grib is not None
        except Exception as e:  # noqa: BLE001
            # A failed probe is not fatal, but keep a trace of it instead
            # of swallowing it silently.
            logger.debug("HRRR availability probe failed for cycle %s: %s",
                         cdt.isoformat(), e)
            last_error = e
            continue
        if found:
            return cdt

    raise RuntimeError(
        f"No HRRR cycle available within last {max_lookback_hours}h of "
        f"{target_dt.isoformat()}") from last_error
if __name__ == "__main__":
    # Smoke test: fetch the f00 analysis and the f01 forecast of yesterday's
    # 12 UTC cycle (through the cache), then print a per-channel summary.
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
    yesterday_noon = one_day_ago.replace(
        hour=12, minute=0, second=0, microsecond=0, tzinfo=None)
    print(f"Smoke test cycle: {yesterday_noon} UTC")

    def _channel_means(tensor: np.ndarray) -> str:
        # "NAME=mean" for every channel, in `_CHANNELS` order.
        return ", ".join(
            f"{name}={tensor[..., i].mean():.2f}"
            for i, (name, _) in enumerate(_CHANNELS))

    analysis = _fetch_with_cache(yesterday_noon, 0)
    print(f" f00: shape={analysis.shape}, dtype={analysis.dtype}, "
          f"mean per channel: " + _channel_means(analysis))

    first_hour = _fetch_with_cache(yesterday_noon, 1)
    print(f" f01: shape={first_hour.shape}, dtype={first_hour.dtype}, "
          f"mean per channel: " + _channel_means(first_hour))

    n_cached = len(list(CACHE_DIR.glob('*.npz')))
    print(f" cache dir: {CACHE_DIR}, n files: {n_cached}")