# StormCast CONUS Regression Model (Full Domain)
## Description
StormCast CONUS is a regression UNet model for high-resolution (3km) weather prediction over the full Continental United States (CONUS) domain. It is trained as the first stage of the StormCast regression-diffusion framework, which autoregressively predicts 99 state variables at km scale using a 1-hour time step, with dense vertical resolution in the atmospheric boundary layer.
This model extends the original StormCast V1 to the full CONUS domain (1056 x 1792 grid) rather than the smaller central US bounding box (512 x 640) used in the original paper.
This checkpoint contains the regression model only. The full StormCast pipeline additionally uses an EDM diffusion model to add fine-grained stochastic structure (storm cells, precipitation bands).
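The autoregressive stepping described above feeds each 1-hour prediction back in as the next input state, together with fresh conditioning. A minimal, model-agnostic sketch of that loop (the `step_fn` and `rollout` names here are illustrative stand-ins, not part of the actual API):

```python
def rollout(step_fn, x0, conditioning_seq):
    """Autoregressive rollout: each step consumes the previous prediction.

    step_fn: one 1-hour forecast step (stand-in for the regression model)
    x0: initial state (e.g. an HRRR analysis)
    conditioning_seq: per-step conditioning (e.g. interpolated GFS/ERA5)
    """
    states = [x0]
    for cond in conditioning_seq:
        states.append(step_fn(states[-1], cond))
    return states

# Toy demonstration with a trivial "model"
forecast = rollout(lambda x, c: x + c, 0.0, [1.0, 1.0, 1.0, 1.0])
print(forecast)  # [0.0, 1.0, 2.0, 3.0, 4.0]
```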
For training recipes, see NVIDIA PhysicsNeMo; for inference, see NVIDIA Earth2Studio.
This model is for research and development only.
## License/Terms of Use
This model is distributed under the Apache 2.0 license.
## Reference(s)
Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling
## Model Architecture
Architecture Type: StormCast uses a UNet architecture in a regression-diffusion generative model framework.
Network Architecture: UNet
| Parameter | Value |
|---|---|
| Model channels | 64 |
| Channel multiplier | [1, 2, 4, 4] |
| Blocks per level | 2 |
| Gradient checkpointing | Level 2 |
| Spatial positional embedding | Yes |
| Precision | BF16 (AMP) |
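The per-level channel widths implied by the table, assuming the usual base-channels x multiplier convention for UNets (a quick sanity check, not code from the model itself):

```python
model_channels = 64
channel_mult = [1, 2, 4, 4]

# Width of each UNet level = base channels times the level multiplier
widths = [model_channels * m for m in channel_mult]
print(widths)  # [64, 128, 256, 256]
```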
## Checkpoints
| Checkpoint | Training Data | Steps | Train Loss | Val Loss | Description |
|---|---|---|---|---|---|
| `StormCastUNet.0.16000.mdlus` | 2023 HRRR+ERA5 (Jan 1-11) | 16,000 | 0.0143 | 0.0125 | Initial v0 training |
| `StormCastUNet.0.32000.mdlus` | 2025 HRRR+GFS (full year) | 32,000 | 0.0128 | 0.0104 | Fine-tuned on 2025 data |
The recommended checkpoint is `StormCastUNet.0.32000.mdlus` (fine-tuned on the full year 2025).
## Input
Input Type(s):
- Tensor (127 channels: 99 state variables + 26 conditioning variables + 2 invariants)
- DateTime (NumPy Array)
Input Format(s): PyTorch Tensor / NumPy array
Input Parameters:
- Four Dimensional (4D) (batch, variable, latitude, longitude)
- Input DateTime (1D)
Other Properties Related to Input:
- Input grid: HRRR Lambert Conformal projection, full CONUS domain (1056 x 1792 at 3km)
- HRRR grid bounds: lat rows 0-1056, lon cols 0-1792
- Input state weather variables (99):
  `u10m,v10m,t2m,msl,u1hl,u2hl,u3hl,u4hl,u5hl,u6hl,u7hl,u8hl,u9hl,u10hl,u11hl,u13hl,u15hl,u20hl,u25hl,u30hl,v1hl,v2hl,v3hl,v4hl,v5hl,v6hl,v7hl,v8hl,v9hl,v10hl,v11hl,v13hl,v15hl,v20hl,v25hl,v30hl,t1hl,t2hl,t3hl,t4hl,t5hl,t6hl,t7hl,t8hl,t9hl,t10hl,t11hl,t13hl,t15hl,t20hl,t25hl,t30hl,q1hl,q2hl,q3hl,q4hl,q5hl,q6hl,q7hl,q8hl,q9hl,q10hl,q11hl,q13hl,q15hl,q20hl,q25hl,q30hl,Z1hl,Z2hl,Z3hl,Z4hl,Z5hl,Z6hl,Z7hl,Z8hl,Z9hl,Z10hl,Z11hl,Z13hl,Z15hl,Z20hl,Z25hl,Z30hl,p1hl,p2hl,p3hl,p4hl,p5hl,p6hl,p7hl,p8hl,p9hl,p10hl,p11hl,p13hl,p15hl,p20hl,refc`
- Conditioning weather variables (26):
  `u10m,v10m,t2m,tcwv,sp,msl,u1000,u850,u500,u250,v1000,v850,v500,v250,z1000,z850,z500,z250,t1000,t850,t500,t250,q1000,q850,q500,q250`
- Invariants (2):
  `lsm` (land-sea mask), `orography`
For variable naming conventions, see the HRRR lexicon in Earth2Studio. Variables suffixed with `hl` are on natural/hybrid model levels.
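The 127 input channels decompose as 99 state + 26 conditioning + 2 invariant channels, concatenated along the channel axis. A toy shape check with NumPy (using a small grid for illustration; the real CONUS grid is 1056 x 1792):

```python
import numpy as np

STATE, COND, INV = 99, 26, 2
B, H, W = 1, 8, 8  # toy spatial grid; the real model uses 1056 x 1792

state = np.zeros((B, STATE, H, W), dtype=np.float32)
conditioning = np.zeros((B, COND, H, W), dtype=np.float32)
invariants = np.zeros((B, INV, H, W), dtype=np.float32)

# Channel-wise concatenation yields the 127-channel network input
x = np.concatenate([state, conditioning, invariants], axis=1)
print(x.shape)  # (1, 127, 8, 8)
```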
## Output
Output Type(s): Tensor (99 surface and model level variables)
Output Format: PyTorch Tensors
Output Parameters: Four Dimensional (4D) (batch, variable, latitude, longitude)
Other Properties Related to Output:
- Output grid: HRRR Lambert Conformal, 1056 x 1792 at 3km
- Output state weather variables: same 99 variables as input
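For planning storage and memory, the size of one output step is easy to estimate (assuming float32 values; actual on-disk size depends on the I/O backend and any compression):

```python
channels, lat, lon = 99, 1056, 1792
n_values = channels * lat * lon
gib = n_values * 4 / 2**30  # 4 bytes per float32
print(f"{n_values:,} values, ~{gib:.2f} GiB per forecast step")
```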
## Software Integration
Supported Hardware Microarchitecture Compatibility:
- NVIDIA Ampere (A100)
- NVIDIA Hopper (H100)
Supported Operating System(s):
- Linux
Dependencies:
- NVIDIA PhysicsNeMo (model loading)
- NVIDIA Earth2Studio (inference pipeline)
## Model Version
Model Version: v0
## Training & Compute
### Training Dataset
#### Stage 1: Initial Training (v0)
Link: HRRR
Data Collection Method: Automatic/Sensors
Properties: HRRR data for the date range of 2023/01/01 to 2023/01/11 (264 hourly samples). The HRRR is a NOAA real-time 3-km resolution, hourly updated, cloud-resolving, convection-allowing atmospheric model, initialized by 3km grids with 3km radar assimilation. Data covers the full CONUS domain (1056 x 1792 grid points).
Link: ERA5 (via ARCO)
Data Collection Method: Automatic/Sensors
Properties: ERA5 data for the date range of 2023/01/01 to 2023/01/11, interpolated to the HRRR grid. ERA5 provides hourly estimates of various atmospheric, land, and oceanic climate variables.
#### Stage 2: Fine-tuning on 2025 Data
Link: HRRR
Data Collection Method: Automatic/Sensors
Properties: HRRR data for the full year 2025 (8,760 hourly samples, ~3.3 TB). Full CONUS domain (1056 x 1792).
Conditioning: GFS (GFS_FX) for 2025, interpolated to the HRRR grid (26 conditioning variables at pressure levels).
Invariants: Land-sea mask and orography from ARCO ERA5, interpolated to the HRRR grid.
Data fetched and prepared using NVIDIA Earth2Studio data APIs (HRRR, GFS_FX, ARCO sources).
### Training Configuration
#### Stage 1: Initial Training
| Parameter | Value |
|---|---|
| Optimizer | Adam (fused) |
| Learning rate | 4e-4 |
| LR rampup steps | 1,000 |
| Total steps | 16,000 |
| Effective batch size | 4 (gradient accumulation) |
| Batch size per GPU | 1 |
| Loss | MSE (regression) |
| Gradient clipping | 1.0 |
| Precision | BF16 (AMP) with TF32 |
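The LR rampup in the table can be read as a linear warmup to the base learning rate over the first 1,000 steps. A small sketch of that schedule (the post-warmup behavior is not stated in this card; a constant LR after the ramp is assumed here):

```python
def lr_at_step(step, base_lr=4e-4, rampup_steps=1_000):
    # Linear warmup to base_lr, then constant (assumed post-warmup schedule)
    return base_lr * min(step / rampup_steps, 1.0)

print(lr_at_step(500))     # halfway through warmup: 2e-4
print(lr_at_step(16_000))  # after warmup: 4e-4
```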
#### Stage 2: Fine-tuning on 2025
| Parameter | Value |
|---|---|
| Optimizer | Adam (fused) |
| Learning rate | 2e-4 |
| LR rampup steps | 500 |
| Total steps | 32,000 (resumed from v0 weights at step 0) |
| Effective batch size | 4 (gradient accumulation) |
| Batch size per GPU | 1 |
| Loss | MSE (regression) |
| Gradient clipping | 1.0 |
| Precision | BF16 (AMP) with TF32 |
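Both stages reach the effective batch size of 4 by accumulating gradients over 4 micro-batches of size 1 before each optimizer step. The mechanics, sketched in plain Python with scalar stand-ins for gradients (the real loop clips the global gradient norm at 1.0; this toy version clamps a scalar):

```python
ACCUM = 4        # micro-batches per optimizer step (effective batch size)
MAX_NORM = 1.0   # gradient clipping threshold

micro_grads = [0.6, -0.2, 3.0, 0.4, 0.1, 0.1, 0.1, 0.1]  # toy per-sample grads
grad, steps, applied = 0.0, 0, []
for i, g in enumerate(micro_grads, start=1):
    grad += g / ACCUM                 # average over the effective batch
    if i % ACCUM == 0:                # one optimizer step per 4 micro-batches
        clipped = max(-MAX_NORM, min(MAX_NORM, grad))  # scalar "norm" clip
        applied.append(clipped)
        steps += 1
        grad = 0.0                    # reset accumulator after the step

print(steps)  # 2 optimizer steps from 8 micro-batches
```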
### Compute
| Resource | Value |
|---|---|
| GPU | 1x NVIDIA H100 80GB |
| Peak GPU memory | ~29 GiB |
| Training speed | ~5.0 s/step |
| Stage 1 training time | ~21 hours (16,000 steps) |
| Stage 2 training time | ~26 hours (32,000 steps) |
| Final train loss | 0.0128 |
| Final val loss | 0.0104 |
## Inference
Test Hardware:
- H100 (80 GB)
### Usage with Earth2Studio
```python
import cartopy
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray as xr
from huggingface_hub import hf_hub_download, snapshot_download

import earth2studio.run as run
from earth2studio.data import HRRR, GFS_FX
from earth2studio.io import ZarrBackend
from earth2studio.models.px import StormCast
from earth2studio.models.px.stormcast import (
    CONDITIONING_VARIABLES, INVARIANTS, VARIABLES,
)
from physicsnemo.core import Module as PhysicsNemoModule

REPO_ID = "kashif/stormcast-regression-conus-v0"


class RegressionOnlyStormCast(StormCast):
    """StormCast wrapper using only the regression model (no diffusion)."""

    @torch.inference_mode()
    def _forward(self, x, conditioning):
        # Normalize conditioning inputs if stats are registered as buffers
        if "conditioning_means" in self._buffers:
            conditioning = conditioning - self.conditioning_means
        if "conditioning_stds" in self._buffers:
            conditioning = conditioning / self.conditioning_stds
        # Normalize state, concatenate state + conditioning + invariants,
        # run the regression UNet, then de-normalize the prediction
        x = (x - self.means) / self.stds
        invariant_tensor = self.invariants.repeat(x.shape[0], 1, 1, 1)
        concats = torch.cat((x, conditioning, invariant_tensor), dim=1)
        out = self.regression_model(concats)
        out = out * self.stds + self.means
        return out


# Download and load checkpoint (use 32000 for 2025-finetuned, 16000 for v0)
ckpt_path = hf_hub_download(REPO_ID, "StormCastUNet.0.32000.mdlus")
regression = PhysicsNemoModule.from_checkpoint(ckpt_path)
diffusion = torch.nn.Identity()

# Download and load normalization stats
means = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "means.npy"))[None, :, None, None]
)
stds = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "stds.npy"))[None, :, None, None]
)
conditioning_means = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "conditioning_means.npy"))[None, :, None, None]
)
conditioning_stds = torch.from_numpy(
    np.load(hf_hub_download(REPO_ID, "conditioning_stds.npy"))[None, :, None, None]
)

# Download and load invariants (land-sea mask, orography)
inv_path = snapshot_download(REPO_ID, allow_patterns="invariants.zarr/**")
inv = xr.open_zarr(f"{inv_path}/invariants.zarr", consolidated=False)
invariants = torch.from_numpy(
    inv["HRRR_invariants"].sel(channel=["lsm", "orography"]).values[None]
)

# Build model for the full CONUS domain
model = RegressionOnlyStormCast(
    regression_model=regression,
    diffusion_model=diffusion,
    means=means,
    stds=stds,
    invariants=invariants,
    hrrr_lat_lim=(0, 1056),
    hrrr_lon_lim=(0, 1792),
    variables=np.array(VARIABLES),
    conditioning_means=conditioning_means,
    conditioning_stds=conditioning_stds,
    conditioning_variables=np.array(CONDITIONING_VARIABLES),
    conditioning_data_source=GFS_FX(),
)

# Run a 4-hour forecast (use any recent date with HRRR/GFS data available)
io = ZarrBackend()
io = run.deterministic(["2026-03-13"], 4, model, HRRR(), io)

# Plot composite reflectivity on the HRRR Lambert Conformal projection
projection = ccrs.LambertConformal(
    central_longitude=262.5, central_latitude=38.5,
    standard_parallels=(38.5, 38.5),
    globe=ccrs.Globe(semimajor_axis=6371229, semiminor_axis=6371229),
)
fig, ax = plt.subplots(subplot_kw={"projection": projection}, figsize=(12, 7))
im = ax.pcolormesh(
    model.lon, model.lat, io["refc"][0, 4],
    transform=ccrs.PlateCarree(), cmap="turbo", vmin=-10, vmax=60,
)
ax.add_feature(
    cartopy.feature.STATES.with_scale("50m"),
    linewidth=0.5, edgecolor="black", zorder=2,
)
ax.coastlines()
ax.gridlines()
ax.set_title("StormCast CONUS - Composite Reflectivity +4h")
fig.colorbar(im, ax=ax, shrink=0.7, label="dBZ")
plt.savefig("refc_prediction.png", dpi=150, bbox_inches="tight")
```
## Example Inference Results
4-hour autoregressive forecast initialized from HRRR analysis on 2026-03-13 00Z, with GFS conditioning. Showing 2m temperature, 10m U-wind, composite reflectivity, and mean sea level pressure.
Initial conditions (+0h):
+2h forecast:
+4h forecast:
## Limitations
- Regression-only: No diffusion model, so predictions are spatially smooth. The full StormCast pipeline requires the EDM diffusion stage for realistic fine-grained storm structure (convective cells, precipitation bands).
- Reflectivity: Composite reflectivity (refc) captures large-scale precipitation patterns (frontal bands, lake-effect regions) but peaks around 24 dBZ — well below real convective storms (>50 dBZ). The diffusion model is essential for realistic storm-cell intensities.
- Reduced architecture: Model channels = 64 (vs 128 in original StormCast V1) to fit full CONUS domain on a single H100 80GB GPU.
- Single-year fine-tuning: Stage 2 trained on full year 2025 (8,760 hourly samples) covering all seasons, but still less diverse than the original V1's 3.5 years (2018-2021). Rare or extreme weather events may be underrepresented.
## Ethical Considerations
NVIDIA believes Trustworthy AI is a shared responsibility. When downloaded or used in accordance with the terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case.
## Citation
```bibtex
@article{pathak2024stormcast,
  title={Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling},
  author={Pathak, Jaideep and others},
  journal={arXiv preprint arXiv:2408.10958},
  year={2024}
}
```


