# irene / tests/test_lightning_model.py
# franch's picture
# Add source code and examples
# df27dfb verified
import numpy as np
import torch
from convgru_ensemble.lightning_model import RadarLightningModel
def test_predict_handles_unpadded_inputs():
    """Check that ``predict`` accepts a raw (unpadded) 8x8 input sequence.

    A minimal single-block model is built, fed an all-zero history of four
    8x8 frames, and the prediction is checked for its expected shape
    ``(1, 2, 8, 8)`` and for containing only finite values.
    """
    steps, members = 2, 1
    model = RadarLightningModel(
        input_channels=1,
        num_blocks=1,
        forecast_steps=steps,
        ensemble_size=members,
        noisy_decoder=False,
    )

    history = np.zeros((4, 8, 8), dtype=np.float32)
    forecast = model.predict(history, forecast_steps=steps, ensemble_size=members)

    # The spatial size of the output should match the (unpadded) input frames.
    assert forecast.shape == (members, steps, *history.shape[1:])
    assert np.isfinite(forecast).all()
def test_from_checkpoint_delegates_to_lightning_loader(monkeypatch):
    """``from_checkpoint`` should forward to ``load_from_checkpoint``.

    The loader is replaced by a stub classmethod that records the arguments it
    receives and returns a sentinel string. The test then checks:

    * the sentinel is returned unchanged;
    * the checkpoint path is forwarded;
    * ``map_location`` is a CPU ``torch.device``;
    * ``strict`` is ``True``.
    """
    seen = {}

    def fake_loader(cls, checkpoint_path, map_location=None, strict=None, weights_only=None):
        # Record what was forwarded instead of reading a real checkpoint.
        seen.update(
            checkpoint_path=checkpoint_path,
            map_location=map_location,
            strict=strict,
            weights_only=weights_only,
        )
        return "loaded-model"

    monkeypatch.setattr(RadarLightningModel, "load_from_checkpoint", classmethod(fake_loader))

    result = RadarLightningModel.from_checkpoint("/tmp/model.ckpt", device="cpu")

    assert result == "loaded-model"
    assert seen["checkpoint_path"] == "/tmp/model.ckpt"

    # The "cpu" string is expected to be converted into a torch.device.
    location = seen["map_location"]
    assert isinstance(location, torch.device)
    assert location.type == "cpu"

    assert seen["strict"] is True
    # NOTE(review): ``weights_only`` is recorded but not asserted. Pin its
    # expected value once the intended default of ``from_checkpoint`` is confirmed.