|
|
"""Model which uses mutliple prediction heads""" |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from pvnet.models.base_model import BaseModel |
|
|
|
|
|
|
|
|
class Ensemble(BaseModel): |
|
|
"""Ensemble of PVNet models""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_list: list[BaseModel], |
|
|
weights: Optional[list[float]] = None, |
|
|
): |
|
|
"""Ensemble of PVNet models |
|
|
|
|
|
Args: |
|
|
model_list: A list of PVNet models to ensemble |
|
|
weights: A list of weighting to apply to each model. If None, the models are weighted |
|
|
equally. |
|
|
""" |
|
|
|
|
|
|
|
|
output_quantiles = [] |
|
|
history_minutes = [] |
|
|
forecast_minutes = [] |
|
|
target_key = [] |
|
|
interval_minutes = [] |
|
|
|
|
|
|
|
|
for model in model_list: |
|
|
output_quantiles.append(model.output_quantiles) |
|
|
history_minutes.append(model.history_minutes) |
|
|
forecast_minutes.append(model.forecast_minutes) |
|
|
target_key.append(model._target_key) |
|
|
interval_minutes.append(model.interval_minutes) |
|
|
|
|
|
|
|
|
for param_list in [ |
|
|
output_quantiles, |
|
|
history_minutes, |
|
|
forecast_minutes, |
|
|
target_key, |
|
|
interval_minutes, |
|
|
]: |
|
|
assert all([p == param_list[0] for p in param_list]), param_list |
|
|
|
|
|
super().__init__( |
|
|
history_minutes=history_minutes[0], |
|
|
forecast_minutes=forecast_minutes[0], |
|
|
optimizer=None, |
|
|
output_quantiles=output_quantiles[0], |
|
|
target_key=target_key[0], |
|
|
interval_minutes=interval_minutes[0], |
|
|
) |
|
|
|
|
|
self.model_list = nn.ModuleList(model_list) |
|
|
|
|
|
if weights is None: |
|
|
weights = torch.ones(len(model_list)) / len(model_list) |
|
|
else: |
|
|
assert len(weights) == len(model_list) |
|
|
weights = torch.Tensor(weights) / sum(weights) |
|
|
self.weights = nn.Parameter(weights, requires_grad=False) |
|
|
|
|
|
def forward(self, batch): |
|
|
"""Run the model forward""" |
|
|
y_hat = 0 |
|
|
for weight, model in zip(self.weights, self.model_list): |
|
|
y_hat = model(batch) * weight + y_hat |
|
|
return y_hat |
|
|
|