File size: 1,330 Bytes
df27dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import torch

from convgru_ensemble.lightning_model import RadarLightningModel


def test_predict_handles_unpadded_inputs():
    """Check that ``predict`` yields a finite forecast of the expected shape.

    A tiny single-block model is built with the decoder noise disabled and
    fed an all-zero float32 history of shape ``(4, 8, 8)``. The forecast is
    expected to have shape ``(ensemble_size, forecast_steps, 8, 8)``. The
    axis order (ensemble, step, height, width) is presumed from the
    assertion below; confirm it against ``predict``'s documentation.
    """
    steps, members = 2, 1
    model = RadarLightningModel(
        input_channels=1,
        num_blocks=1,
        forecast_steps=steps,
        ensemble_size=members,
        noisy_decoder=False,
    )
    history = np.zeros((4, 8, 8), dtype=np.float32)

    forecast = model.predict(history, forecast_steps=steps, ensemble_size=members)

    assert forecast.shape == (members, steps, 8, 8)
    assert np.isfinite(forecast).all()


def test_from_checkpoint_delegates_to_lightning_loader(monkeypatch):
    """Check that ``from_checkpoint`` forwards to ``load_from_checkpoint``.

    The class's ``load_from_checkpoint`` is replaced with a recording stub,
    so the test never reads a checkpoint from disk. The test then verifies
    three things:

    - the path is passed through unchanged;
    - ``device="cpu"`` arrives as a ``torch.device`` ``map_location``;
    - ``strict`` is set to ``True``.
    """
    recorded = {}

    # The keyword names must match what ``from_checkpoint`` passes to the loader.
    def stub_loader(cls, checkpoint_path, map_location=None, strict=None, weights_only=None):
        recorded.update(
            checkpoint_path=checkpoint_path,
            map_location=map_location,
            strict=strict,
            weights_only=weights_only,
        )
        return "loaded-model"

    # Wrap the stub in classmethod so it binds like the real loader.
    monkeypatch.setattr(RadarLightningModel, "load_from_checkpoint", classmethod(stub_loader))

    result = RadarLightningModel.from_checkpoint("/tmp/model.ckpt", device="cpu")

    assert result == "loaded-model"
    assert recorded["checkpoint_path"] == "/tmp/model.ckpt"
    location = recorded["map_location"]
    assert isinstance(location, torch.device)
    assert location.type == "cpu"
    assert recorded["strict"] is True