File size: 969 Bytes
a5be142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pvnet.models.ensemble import Ensemble


def test_model_init(multimodal_model):
    """Check that an Ensemble can be built both without and with explicit weights."""
    # Same order as two separate constructions: `weights=None` first, then a
    # list with one weight per member model.
    for weights in (None, [1, 2, 3]):
        Ensemble(
            model_list=[multimodal_model] * 3,
            weights=weights,
        )


def test_model_forward(multimodal_model, sample_batch):
    """Run a batch through a three-member Ensemble and check the output shape."""
    # The same model instance is repeated three times; no weights are passed.
    members = [multimodal_model] * 3
    ensemble_model = Ensemble(model_list=members)

    # The output must be 2-D with shape (2, 16).
    # NOTE(review): presumably (batch size, forecast length). An earlier comment
    # here said forecast_len=15, which disagrees with the asserted 16 - confirm
    # against the fixtures.
    expected_shape = (2, 16)

    y = ensemble_model(sample_batch)
    assert tuple(y.shape) == expected_shape, y.shape


def test_quantile_model_forward(multimodal_quantile_model, sample_batch):
    """Run a batch through an Ensemble of quantile models and check the output shape."""
    # The same quantile model instance is repeated three times; no weights are passed.
    members = [multimodal_quantile_model] * 3
    ensemble_model = Ensemble(model_list=members)

    # The output must be 3-D with shape (2, 16, 3).
    # NOTE(review): presumably (batch size, forecast length, number of quantiles).
    # An earlier comment here said forecast_len=15, which disagrees with the
    # asserted 16 - confirm against the fixtures.
    expected_shape = (2, 16, 3)

    y_quantiles = ensemble_model(sample_batch)
    assert tuple(y_quantiles.shape) == expected_shape, y_quantiles.shape