| import numpy as np | |
| import xarray as xr | |
| from convgru_ensemble.lightning_model import RadarLightningModel | |
def test_inference_on_sample_data():
    """End-to-end inference on the sample NetCDF with a freshly initialized model.

    Feeds the first ``num_past`` frames of the ``RR`` variable, at full spatial
    extent, through an untrained ``RadarLightningModel``. It then checks the
    forecast's shape (ensemble, steps, H, W), that every value is finite, and
    that the dtype is float32 or float64.
    """
    # Single source of truth for the sizes used in the constructor, the
    # predict() call and the shape assertion, so they cannot drift apart.
    num_past = 6
    forecast_steps = 4
    ensemble_size = 1

    # Slice the DataArray *before* materializing it, so only the frames we use
    # are read from disk (the original note gives RR as (54, 1400, 1200)).
    # The context manager closes the underlying file handle afterwards.
    with xr.open_dataset("examples/sample_data.nc") as ds:
        past = ds["RR"][:num_past].values.astype(np.float32)
    _, H, W = past.shape

    model = RadarLightningModel(
        input_channels=1,
        num_blocks=3,
        forecast_steps=forecast_steps,
        ensemble_size=ensemble_size,
        noisy_decoder=False,
    )
    preds = model.predict(
        past, forecast_steps=forecast_steps, ensemble_size=ensemble_size
    )

    assert preds.shape == (ensemble_size, forecast_steps, H, W)
    assert np.isfinite(preds).all()
    assert preds.dtype in (np.float32, np.float64)