irene / tests /test_inference.py
franch's picture
Add source code and examples
df27dfb verified
raw
history blame contribute delete
817 Bytes
from pathlib import Path

import numpy as np
import xarray as xr

from convgru_ensemble.lightning_model import RadarLightningModel
def test_inference_on_sample_data():
    """End-to-end inference on the sample NetCDF with a freshly initialized model.

    Reads the first ``n_past`` frames of the ``RR`` variable (full spatial
    extent) from ``examples/sample_data.nc``, runs them through an untrained
    ``RadarLightningModel.predict`` and checks the returned forecast's shape,
    finiteness and floating-point dtype.
    """
    n_past = 6
    forecast_steps = 4
    ensemble_size = 1

    # Resolve the fixture relative to this file (tests/ -> repo root) so the
    # test does not depend on the directory pytest is launched from.
    repo_root = Path(__file__).resolve().parents[1]
    data_path = repo_root / "examples" / "sample_data.nc"

    # Close the dataset deterministically. Slicing before ``.values`` reads
    # only the frames we need, and ``.values`` yields an in-memory NumPy
    # array that stays valid after the file is closed.
    with xr.open_dataset(data_path) as ds:
        # NOTE(review): the full variable was noted as (54, 1400, 1200);
        # the leading axis is presumed to be time — confirm against the file.
        past = ds["RR"][:n_past].values

    past = past.astype(np.float32)
    # Guard against a shorter file silently producing fewer input frames.
    assert past.shape[0] == n_past
    _, height, width = past.shape

    model = RadarLightningModel(
        input_channels=1,
        num_blocks=3,
        forecast_steps=forecast_steps,
        ensemble_size=ensemble_size,
        noisy_decoder=False,
    )
    preds = model.predict(
        past, forecast_steps=forecast_steps, ensemble_size=ensemble_size
    )

    # Leading dim of 1 — presumably the ensemble-member axis (ensemble_size=1
    # here); TODO confirm against RadarLightningModel.predict.
    assert preds.shape == (1, forecast_steps, height, width)
    assert np.isfinite(preds).all()
    assert preds.dtype in (np.float32, np.float64)