"""Tests for ``RadarLightningModel.predict`` and ``RadarLightningModel.from_checkpoint``."""

import numpy as np
import torch

from convgru_ensemble.lightning_model import RadarLightningModel


def test_predict_handles_unpadded_inputs():
    """``predict`` on a (4, 8, 8) zero array yields a finite (1, 2, 8, 8) result.

    The model is built with a single block, a single ensemble member and a
    non-noisy decoder. NOTE(review): going by the test name, the 8x8 frames
    are meant to exercise the path where the caller has not padded the input
    beforehand -- confirm against ``predict``'s implementation.
    """
    radar_model = RadarLightningModel(
        input_channels=1,
        num_blocks=1,
        forecast_steps=2,
        ensemble_size=1,
        noisy_decoder=False,
    )
    history = np.zeros((4, 8, 8), dtype=np.float32)

    forecast = radar_model.predict(history, forecast_steps=2, ensemble_size=1)

    # One ensemble member, two forecast steps, spatial size equal to the input.
    assert forecast.shape == (1, 2, 8, 8)
    assert np.isfinite(forecast).all()


def test_from_checkpoint_delegates_to_lightning_loader(monkeypatch):
    """``from_checkpoint`` forwards its arguments to ``load_from_checkpoint``.

    ``load_from_checkpoint`` is replaced by a stub that records what it was
    called with, so no checkpoint file is actually read by the stub.
    """
    recorded = {}

    def fake_loader(cls, checkpoint_path, map_location=None, strict=None, weights_only=None):
        # Remember every argument so the assertions below can inspect them.
        recorded.update(
            checkpoint_path=checkpoint_path,
            map_location=map_location,
            strict=strict,
            weights_only=weights_only,
        )
        return "loaded-model"

    # Wrap in classmethod so the stub is bound the same way as the attribute
    # it replaces (it receives ``cls`` as its first argument).
    monkeypatch.setattr(RadarLightningModel, "load_from_checkpoint", classmethod(fake_loader))

    result = RadarLightningModel.from_checkpoint("/tmp/model.ckpt", device="cpu")

    # The stub's return value must be passed straight back to the caller.
    assert result == "loaded-model"
    assert recorded["checkpoint_path"] == "/tmp/model.ckpt"

    # The "cpu" device string is expected to arrive as a torch.device object.
    forwarded_location = recorded["map_location"]
    assert isinstance(forwarded_location, torch.device)
    assert forwarded_location.type == "cpu"

    assert recorded["strict"] is True