StormCast CONUS Regression Model (Full Domain)

Description

StormCast CONUS is a regression UNet model for high-resolution (3km) weather prediction over the full Continental United States (CONUS) domain. It is trained as the first stage of the StormCast regression-diffusion framework, which autoregressively predicts 99 state variables at kilometer scale with a 1-hour time step and dense vertical resolution in the atmospheric boundary layer.

This model extends the original StormCast V1 to the full CONUS domain (1056 x 1792 grid) rather than the smaller central US bounding box (512 x 640) used in the original paper.

This checkpoint contains the regression model only. The full StormCast pipeline additionally uses an EDM diffusion model to add fine-grained stochastic structure (storm cells, precipitation bands).

For training recipes, see NVIDIA PhysicsNeMo; for inference, see NVIDIA Earth2Studio.

This model is for research and development only.

License/Terms of Use

This model is distributed under the Apache 2.0 license.

Reference(s)

Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling

Model Architecture

Architecture Type: StormCast uses a UNet architecture in a regression-diffusion generative model framework.
Network Architecture: UNet

Parameter                       Value
Model channels                  64
Channel multiplier              [1, 2, 4, 4]
Blocks per level                2
Gradient checkpointing          Level 2
Spatial positional embedding    Yes
Precision                       BF16 (AMP)
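The per-level encoder widths follow from the table, assuming the usual EDM-style convention that the width at each level is the base channel count times that level's multiplier:

```python
# Encoder widths implied by the architecture table (assumed convention:
# width at level i = model_channels * channel_mult[i]).
model_channels = 64
channel_mult = [1, 2, 4, 4]
level_widths = [model_channels * m for m in channel_mult]
print(level_widths)  # [64, 128, 256, 256]
```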

Input

Input Type(s):

  • Tensor (127 channels: 99 state variables + 26 conditioning variables + 2 invariants)
  • DateTime (NumPy Array)

Input Format(s): PyTorch Tensor / NumPy array
Input Parameters:

  • Four Dimensional (4D) (batch, variable, latitude, longitude)
  • Input DateTime (1D)

Other Properties Related to Input:

  • Input grid: HRRR Lambert Conformal projection, full CONUS domain (1056 x 1792 at 3km)
  • HRRR grid bounds: lat rows 0-1056, lon cols 0-1792
  • Input state weather variables (99): u10m, v10m, t2m, msl, u1hl, u2hl, u3hl, u4hl, u5hl, u6hl, u7hl, u8hl, u9hl, u10hl, u11hl, u13hl, u15hl, u20hl, u25hl, u30hl, v1hl, v2hl, v3hl, v4hl, v5hl, v6hl, v7hl, v8hl, v9hl, v10hl, v11hl, v13hl, v15hl, v20hl, v25hl, v30hl, t1hl, t2hl, t3hl, t4hl, t5hl, t6hl, t7hl, t8hl, t9hl, t10hl, t11hl, t13hl, t15hl, t20hl, t25hl, t30hl, q1hl, q2hl, q3hl, q4hl, q5hl, q6hl, q7hl, q8hl, q9hl, q10hl, q11hl, q13hl, q15hl, q20hl, q25hl, q30hl, Z1hl, Z2hl, Z3hl, Z4hl, Z5hl, Z6hl, Z7hl, Z8hl, Z9hl, Z10hl, Z11hl, Z13hl, Z15hl, Z20hl, Z25hl, Z30hl, p1hl, p2hl, p3hl, p4hl, p5hl, p6hl, p7hl, p8hl, p9hl, p10hl, p11hl, p13hl, p15hl, p20hl, refc
  • Conditioning weather variables (26): u10m, v10m, t2m, tcwv, sp, msl, u1000, u850, u500, u250, v1000, v850, v500, v250, z1000, z850, z500, z250, t1000, t850, t500, t250, q1000, q850, q500, q250
  • Invariants (2): lsm (land-sea mask), orography

For variable naming conventions, see the HRRR Lexicon in Earth2Studio. Variables with the hl suffix are on natural (hybrid) model levels.
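The 99 state variables above follow a regular pattern that can be reconstructed programmatically; this sketch rebuilds the list from the level sets read off the enumeration (u, v, t, q, and Z share one set of 16 hybrid levels; p stops at level 20):

```python
# Reconstruct the 99 state-variable names from the hybrid-level pattern.
# "Nhl" denotes hybrid model level N.
full_levels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30]
p_levels = full_levels[:-2]  # pressure variables stop at level 20

state_vars = ["u10m", "v10m", "t2m", "msl"]            # surface variables
for v in ["u", "v", "t", "q", "Z"]:                    # 5 vars x 16 levels
    state_vars += [f"{v}{lev}hl" for lev in full_levels]
state_vars += [f"p{lev}hl" for lev in p_levels]        # 14 pressure levels
state_vars += ["refc"]                                 # composite reflectivity

print(len(state_vars))  # 99
```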

Output

Output Type(s): Tensor (99 surface and model level variables)
Output Format: PyTorch Tensors
Output Parameters: Four Dimensional (4D) (batch, variable, latitude, longitude)
Other Properties Related to Output:

  • Output grid: HRRR Lambert Conformal, 1056 x 1792 at 3km
  • Output state weather variables: same 99 variables as input

Software Integration

Supported Hardware Microarchitecture Compatibility:

  • NVIDIA Ampere (A100)
  • NVIDIA Hopper (H100)

Supported Operating System(s):

  • Linux

Dependencies:

  • NVIDIA PhysicsNeMo (checkpoint loading)
  • NVIDIA Earth2Studio (inference pipeline)
  • PyTorch, NumPy, xarray, huggingface_hub
  • Cartopy and Matplotlib (plotting example only)

Model Version

Model Version: v0 (regression-only, preliminary training)

Training & Compute

Training Dataset

Link: HRRR

Data Collection Method: Automatic/Sensors

Properties: HRRR data for the date range of 2023/01/01 to 2023/01/11 (264 hourly samples). The HRRR is a NOAA real-time 3-km resolution, hourly updated, cloud-resolving, convection-allowing atmospheric model, initialized by 3km grids with 3km radar assimilation. Data covers the full CONUS domain (1056 x 1792 grid points).

Link: ERA5 (via ARCO)

Data Collection Method: Automatic/Sensors

Properties: ERA5 data for the date range of 2023/01/01 to 2023/01/11, interpolated to the HRRR grid. ERA5 provides hourly estimates of various atmospheric, land, and oceanic climate variables. The data covers the Earth on a 25km grid and resolves the atmosphere at 137 levels.

Training Configuration

Parameter               Value
Optimizer               Adam (fused)
Learning rate           4e-4
LR rampup steps         1,000
Total steps             16,000
Effective batch size    4 (gradient accumulation)
Batch size per GPU      1
Loss                    MSE (regression)
Gradient clipping       1.0
Precision               BF16 (AMP) with TF32
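The configuration above can be sketched as a plain PyTorch loop; this is a minimal, hedged illustration (a toy conv layer stands in for the StormCast UNet, the warmup is assumed to be linear, and the BF16/TF32 autocast setup is omitted for brevity):

```python
import torch

# Toy stand-in for the UNet: 127 input channels -> 99 output channels.
model = torch.nn.Conv2d(127, 99, kernel_size=3, padding=1)
base_lr, rampup_steps, accum = 4e-4, 1_000, 4  # values from the table
opt = torch.optim.Adam(model.parameters(), lr=base_lr)

def lr_at(step):
    # Assumed schedule: linear warmup to base_lr over 1,000 steps, then constant.
    return base_lr * min(step / rampup_steps, 1.0)

for step in range(2):  # a couple of demo steps
    for group in opt.param_groups:
        group["lr"] = lr_at(step + 1)
    opt.zero_grad()
    # Batch size 1 per micro-step, accumulated to an effective batch of 4.
    for _ in range(accum):
        x = torch.randn(1, 127, 32, 32)       # tiny spatial dims for the demo
        target = torch.randn(1, 99, 32, 32)
        loss = torch.nn.functional.mse_loss(model(x), target) / accum
        loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip at 1.0
    opt.step()
```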

Compute

Resource            Value
GPU                 1x NVIDIA H100 80GB
Peak GPU memory     ~29 GiB
Training speed      ~4.5 s/step (with grad accum)
Training time       ~21 hours (16,000 steps)
Final train loss    0.0143
Final val loss      0.0125

Inference

Test Hardware:

  • H100 (80 GB)

Usage with Earth2Studio

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from earth2studio.data import HRRR, GFS_FX
from earth2studio.io import ZarrBackend
from earth2studio.models.px import StormCast
from earth2studio.models.px.stormcast import (
    CONDITIONING_VARIABLES, INVARIANTS, VARIABLES,
)
from physicsnemo.core import Module as PhysicsNemoModule
import earth2studio.run as run

REPO_ID = "kashif/stormcast-regression-conus-v0"


class RegressionOnlyStormCast(StormCast):
    """StormCast wrapper using only the regression model (no diffusion)."""

    @torch.inference_mode()
    def _forward(self, x, conditioning):
        if "conditioning_means" in self._buffers:
            conditioning = conditioning - self.conditioning_means
        if "conditioning_stds" in self._buffers:
            conditioning = conditioning / self.conditioning_stds
        x = (x - self.means) / self.stds
        invariant_tensor = self.invariants.repeat(x.shape[0], 1, 1, 1)
        concats = torch.cat((x, conditioning, invariant_tensor), dim=1)
        out = self.regression_model(concats)
        out = out * self.stds + self.means
        return out


# Download and load checkpoint
ckpt_path = hf_hub_download(REPO_ID, "StormCastUNet.0.16000.mdlus")
regression = PhysicsNemoModule.from_checkpoint(ckpt_path)
diffusion = torch.nn.Identity()

# Download and load normalization stats
means = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "means.npy"))[None, :, None, None]
)
stds = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "stds.npy"))[None, :, None, None]
)
conditioning_means = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "conditioning_means.npy"))[None, :, None, None]
)
conditioning_stds = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "conditioning_stds.npy"))[None, :, None, None]
)

# Download and load invariants
import xarray as xr
from huggingface_hub import snapshot_download
inv_path = snapshot_download(REPO_ID, allow_patterns="invariants.zarr/**")
inv = xr.open_zarr(f"{inv_path}/invariants.zarr", consolidated=False)
invariants = torch.from_numpy(
    inv["HRRR_invariants"].sel(channel=["lsm", "orography"]).values[None]
)

# Build model for full CONUS
model = RegressionOnlyStormCast(
    regression_model=regression,
    diffusion_model=diffusion,
    means=means,
    stds=stds,
    invariants=invariants,
    hrrr_lat_lim=(0, 1056),
    hrrr_lon_lim=(0, 1792),
    variables=np.array(VARIABLES),
    conditioning_means=conditioning_means,
    conditioning_stds=conditioning_stds,
    conditioning_variables=np.array(CONDITIONING_VARIABLES),
    conditioning_data_source=GFS_FX(),
)

# Run 4-hour forecast
io = ZarrBackend()
io = run.deterministic(["2026-02-12"], 4, model, HRRR(), io)

# Plot composite reflectivity
import cartopy
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

projection = ccrs.LambertConformal(
    central_longitude=262.5, central_latitude=38.5,
    standard_parallels=(38.5, 38.5),
    globe=ccrs.Globe(semimajor_axis=6371229, semiminor_axis=6371229),
)
fig, ax = plt.subplots(subplot_kw={"projection": projection}, figsize=(12, 7))
im = ax.pcolormesh(
    model.lon, model.lat, io["refc"][0, 4],
    transform=ccrs.PlateCarree(), cmap="turbo", vmin=-10, vmax=60,
)
ax.add_feature(
    cartopy.feature.STATES.with_scale("50m"),
    linewidth=0.5, edgecolor="black", zorder=2,
)
ax.coastlines()
ax.gridlines()
ax.set_title("StormCast CONUS - Composite Reflectivity +4h")
fig.colorbar(im, ax=ax, shrink=0.7, label="dBZ")
plt.savefig("refc_prediction.png", dpi=150, bbox_inches="tight")

Limitations

  • Preliminary training: Trained on only 11 days of January 2023 data (264 hourly samples). The original StormCast V1 used July 2018 - December 2021 (~3.5 years).
  • Regression-only: No diffusion model, so predictions are smooth/blurry. The full StormCast pipeline requires the diffusion stage for realistic storm structure.
  • Reflectivity: Composite reflectivity (refc) predictions are near-uniform because MSE regression on a sparse variable converges to the mean. The diffusion model is essential for realistic refc.
  • Reduced architecture: Model channels = 64 (vs 128 in original) to fit full CONUS on a single H100.
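The reflectivity limitation above can be illustrated numerically: for a sparse field that is mostly zero with occasional large values, the MSE-optimal constant prediction is the field mean, which washes out the peaks (a toy example with made-up reflectivity statistics):

```python
import numpy as np

rng = np.random.default_rng(0)
refc = np.zeros(10_000)
storms = rng.random(10_000) < 0.02                 # ~2% of pixels hold storms
refc[storms] = rng.uniform(30, 60, storms.sum())   # strong reflectivity there

# Search over constant predictions: MSE is minimized near the field mean,
# far below the 30-60 dBZ values inside actual storm cells.
candidates = np.linspace(0, 60, 601)
best = candidates[np.argmin([((refc - c) ** 2).mean() for c in candidates])]
print(best, refc.mean())
```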

Ethical Considerations

NVIDIA believes Trustworthy AI is a shared responsibility. When downloaded or used in accordance with the terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case.

Citation

@article{pathak2024stormcast,
  title={Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling},
  author={Pathak, Jaideep and others},
  journal={arXiv preprint arXiv:2408.10958},
  year={2024}
}