File size: 2,335 Bytes
cbe6208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Ensemble model which combines the predictions of multiple PVNet models"""
from typing import Optional

import torch
from torch import nn

from pvnet.models.base_model import BaseModel


class Ensemble(BaseModel):
    """Weighted ensemble of PVNet models.

    All member models must agree on their output quantiles, history/forecast lengths, target key
    and interval length. The ensemble prediction is the weighted sum of the member models'
    predictions, with the weights normalised to sum to one.
    """

    def __init__(
        self,
        model_list: list[BaseModel],
        weights: Optional[list[float]] = None,
    ):
        """Ensemble of PVNet models

        Args:
            model_list: A non-empty list of PVNet models to ensemble.
            weights: Relative weight to apply to each model, one entry per model in `model_list`.
                The weights are normalised to sum to one. If None, the models are weighted
                equally.

        Raises:
            ValueError: If `model_list` is empty, if the models do not all share the same
                output_quantiles, history_minutes, forecast_minutes, target key and
                interval_minutes, if `weights` does not have one entry per model, or if
                `weights` sums to zero.
        """
        # Validation raises explicitly rather than using `assert`, which is stripped when Python
        # runs with -O and would let incompatible models be ensembled silently
        if len(model_list) == 0:
            raise ValueError("model_list must contain at least one model")

        # Surface check that all the models are compatible: read these properties from each
        # model and require them to be identical across the ensemble
        model_properties = {
            "output_quantiles": [model.output_quantiles for model in model_list],
            "history_minutes": [model.history_minutes for model in model_list],
            "forecast_minutes": [model.forecast_minutes for model in model_list],
            "target_key": [model._target_key for model in model_list],
            "interval_minutes": [model.interval_minutes for model in model_list],
        }
        for property_name, values in model_properties.items():
            if not all(value == values[0] for value in values):
                raise ValueError(
                    f"All models in the ensemble must have the same {property_name}, "
                    f"got {values}"
                )

        super().__init__(
            history_minutes=model_properties["history_minutes"][0],
            forecast_minutes=model_properties["forecast_minutes"][0],
            optimizer=None,
            output_quantiles=model_properties["output_quantiles"][0],
            target_key=model_properties["target_key"][0],
            interval_minutes=model_properties["interval_minutes"][0],
        )

        self.model_list = nn.ModuleList(model_list)

        n_models = len(model_list)
        if weights is None:
            weights = torch.ones(n_models) / n_models
        else:
            if len(weights) != n_models:
                raise ValueError(
                    f"Expected one weight per model ({n_models}), got {len(weights)} weights"
                )
            total_weight = sum(weights)
            # Normalising by a zero total would produce inf/NaN weights
            if total_weight == 0:
                raise ValueError("weights must not sum to zero")
            weights = torch.Tensor(weights) / total_weight

        # Stored as a non-trainable nn.Parameter so it is registered on the module (saved in the
        # state dict and moved along with the module by .to()/.cuda())
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, batch):
        """Run the model forward

        Returns the weighted sum of each member model's prediction on `batch`, using the
        normalised ensemble weights.
        """
        y_hat = 0
        for weight, model in zip(self.weights, self.model_list):
            y_hat = model(batch) * weight + y_hat
        return y_hat