Commit ·
7bffb2f
1
Parent(s): b74423e
Delete pvnet
Browse files- pvnet/__init__.py +0 -2
- pvnet/callbacks.py +0 -129
- pvnet/data/__init__.py +0 -3
- pvnet/data/base_datamodule.py +0 -118
- pvnet/data/site_datamodule.py +0 -53
- pvnet/data/uk_regional_datamodule.py +0 -54
- pvnet/load_model.py +0 -71
- pvnet/models/__init__.py +0 -1
- pvnet/models/base_model.py +0 -973
- pvnet/models/baseline/__init__.py +0 -1
- pvnet/models/baseline/last_value.py +0 -42
- pvnet/models/baseline/readme.md +0 -5
- pvnet/models/baseline/single_value.py +0 -36
- pvnet/models/ensemble.py +0 -74
- pvnet/models/model_cards/pv_india_model_card_template.md +0 -56
- pvnet/models/model_cards/pv_uk_regional_model_card_template.md +0 -59
- pvnet/models/model_cards/wind_india_model_card_template.md +0 -56
- pvnet/models/multimodal/__init__.py +0 -1
- pvnet/models/multimodal/basic_blocks.py +0 -104
- pvnet/models/multimodal/encoders/__init__.py +0 -1
- pvnet/models/multimodal/encoders/basic_blocks.py +0 -217
- pvnet/models/multimodal/encoders/encoders2d.py +0 -413
- pvnet/models/multimodal/encoders/encoders3d.py +0 -402
- pvnet/models/multimodal/encoders/encodersRNN.py +0 -141
- pvnet/models/multimodal/linear_networks/__init__.py +0 -1
- pvnet/models/multimodal/linear_networks/basic_blocks.py +0 -121
- pvnet/models/multimodal/linear_networks/networks.py +0 -332
- pvnet/models/multimodal/multimodal.py +0 -417
- pvnet/models/multimodal/readme.md +0 -11
- pvnet/models/multimodal/site_encoders/__init__.py +0 -1
- pvnet/models/multimodal/site_encoders/basic_blocks.py +0 -35
- pvnet/models/multimodal/site_encoders/encoders.py +0 -284
- pvnet/models/multimodal/unimodal_teacher.py +0 -447
- pvnet/models/utils.py +0 -123
- pvnet/optimizers.py +0 -200
- pvnet/training.py +0 -183
- pvnet/utils.py +0 -321
pvnet/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
"""PVNet"""
|
| 2 |
-
__version__ = "4.1.18"
|
|
|
|
|
|
|
|
|
pvnet/callbacks.py
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
"""Custom callbacks
|
| 2 |
-
"""
|
| 3 |
-
from lightning.pytorch import Trainer
|
| 4 |
-
from lightning.pytorch.callbacks import BaseFinetuning, EarlyStopping, LearningRateFinder
|
| 5 |
-
from lightning.pytorch.trainer.states import TrainerFn
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class PhaseEarlyStopping(EarlyStopping):
|
| 9 |
-
"""Monitor a validation metric and stop training when it stops improving.
|
| 10 |
-
|
| 11 |
-
Only functions in a specific phase of training.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
training_phase = None
|
| 15 |
-
|
| 16 |
-
def switch_phase(self, phase: str):
|
| 17 |
-
"""Switch phase of callback"""
|
| 18 |
-
if phase == self.training_phase:
|
| 19 |
-
self.activate()
|
| 20 |
-
else:
|
| 21 |
-
self.deactivate()
|
| 22 |
-
|
| 23 |
-
def deactivate(self):
|
| 24 |
-
"""Deactivate callback"""
|
| 25 |
-
self.active = False
|
| 26 |
-
|
| 27 |
-
def activate(self):
|
| 28 |
-
"""Activate callback"""
|
| 29 |
-
self.active = True
|
| 30 |
-
|
| 31 |
-
def _should_skip_check(self, trainer: Trainer) -> bool:
|
| 32 |
-
return (
|
| 33 |
-
(trainer.state.fn != TrainerFn.FITTING) or (trainer.sanity_checking) or not self.active
|
| 34 |
-
)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class PretrainEarlyStopping(EarlyStopping):
|
| 38 |
-
"""Monitor a validation metric and stop training when it stops improving.
|
| 39 |
-
|
| 40 |
-
Only functions in the 'pretrain' phase of training.
|
| 41 |
-
"""
|
| 42 |
-
|
| 43 |
-
training_phase = "pretrain"
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class MainEarlyStopping(EarlyStopping):
|
| 47 |
-
"""Monitor a validation metric and stop training when it stops improving.
|
| 48 |
-
|
| 49 |
-
Only functions in the 'main' phase of training.
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
training_phase = "main"
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class PretrainFreeze(BaseFinetuning):
|
| 56 |
-
"""Freeze the satellite and NWP encoders during pretraining"""
|
| 57 |
-
|
| 58 |
-
training_phase = "pretrain"
|
| 59 |
-
|
| 60 |
-
def __init__(self):
|
| 61 |
-
"""Freeze the satellite and NWP encoders during pretraining"""
|
| 62 |
-
super().__init__()
|
| 63 |
-
|
| 64 |
-
def freeze_before_training(self, pl_module):
|
| 65 |
-
"""Freeze satellite and NWP encoders before training start"""
|
| 66 |
-
# freeze any module you want
|
| 67 |
-
modules = []
|
| 68 |
-
if pl_module.include_sat:
|
| 69 |
-
modules += [pl_module.sat_encoder]
|
| 70 |
-
if pl_module.include_nwp:
|
| 71 |
-
modules += [pl_module.nwp_encoder]
|
| 72 |
-
self.freeze(modules)
|
| 73 |
-
|
| 74 |
-
def finetune_function(self, pl_module, current_epoch, optimizer):
|
| 75 |
-
"""Unfreeze satellite and NWP encoders"""
|
| 76 |
-
if not self.active:
|
| 77 |
-
modules = []
|
| 78 |
-
if pl_module.include_sat:
|
| 79 |
-
modules += [pl_module.sat_encoder]
|
| 80 |
-
if pl_module.include_nwp:
|
| 81 |
-
modules += [pl_module.nwp_encoder]
|
| 82 |
-
self.unfreeze_and_add_param_group(
|
| 83 |
-
modules=modules,
|
| 84 |
-
optimizer=optimizer,
|
| 85 |
-
train_bn=True,
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
def switch_phase(self, phase: str):
|
| 89 |
-
"""Switch phase of callback"""
|
| 90 |
-
if phase == self.training_phase:
|
| 91 |
-
self.activate()
|
| 92 |
-
else:
|
| 93 |
-
self.deactivate()
|
| 94 |
-
|
| 95 |
-
def deactivate(self):
|
| 96 |
-
"""Deactivate callback"""
|
| 97 |
-
self.active = False
|
| 98 |
-
|
| 99 |
-
def activate(self):
|
| 100 |
-
"""Activate callback"""
|
| 101 |
-
self.active = True
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class PhasedLearningRateFinder(LearningRateFinder):
|
| 105 |
-
"""Finds a learning rate at the start of each phase of learning"""
|
| 106 |
-
|
| 107 |
-
active = True
|
| 108 |
-
|
| 109 |
-
def on_fit_start(self, *args, **kwargs):
|
| 110 |
-
"""Do nothing"""
|
| 111 |
-
return
|
| 112 |
-
|
| 113 |
-
def on_train_epoch_start(self, trainer, pl_module):
|
| 114 |
-
"""Run learning rate finder on epoch start and then deactivate"""
|
| 115 |
-
if self.active:
|
| 116 |
-
self.lr_find(trainer, pl_module)
|
| 117 |
-
self.deactivate()
|
| 118 |
-
|
| 119 |
-
def switch_phase(self, phase: str):
|
| 120 |
-
"""Switch training phase"""
|
| 121 |
-
self.activate()
|
| 122 |
-
|
| 123 |
-
def deactivate(self):
|
| 124 |
-
"""Deactivate callback"""
|
| 125 |
-
self.active = False
|
| 126 |
-
|
| 127 |
-
def activate(self):
|
| 128 |
-
"""Activate callback"""
|
| 129 |
-
self.active = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/data/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
"""Data parts"""
|
| 2 |
-
from .site_datamodule import SiteDataModule
|
| 3 |
-
from .uk_regional_datamodule import DataModule
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/data/base_datamodule.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
""" Data module for pytorch lightning """
|
| 2 |
-
|
| 3 |
-
from glob import glob
|
| 4 |
-
|
| 5 |
-
from lightning.pytorch import LightningDataModule
|
| 6 |
-
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
|
| 7 |
-
from ocf_data_sampler.torch_datasets.sample.base import (
|
| 8 |
-
NumpyBatch,
|
| 9 |
-
SampleBase,
|
| 10 |
-
TensorBatch,
|
| 11 |
-
batch_to_tensor,
|
| 12 |
-
)
|
| 13 |
-
from torch.utils.data import DataLoader, Dataset
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
|
| 17 |
-
"""Convert a list of NumpySample samples to a tensor batch"""
|
| 18 |
-
return batch_to_tensor(stack_np_samples_into_batch(samples))
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class PremadeSamplesDataset(Dataset):
|
| 22 |
-
"""Dataset to load samples from
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
sample_dir: Path to the directory of pre-saved samples.
|
| 26 |
-
sample_class: sample class type to use for save/load/to_numpy
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
-
def __init__(self, sample_dir: str, sample_class: SampleBase):
|
| 30 |
-
"""Initialise PremadeSamplesDataset"""
|
| 31 |
-
self.sample_paths = glob(f"{sample_dir}/*")
|
| 32 |
-
self.sample_class = sample_class
|
| 33 |
-
|
| 34 |
-
def __len__(self):
|
| 35 |
-
return len(self.sample_paths)
|
| 36 |
-
|
| 37 |
-
def __getitem__(self, idx):
|
| 38 |
-
sample = self.sample_class.load(self.sample_paths[idx])
|
| 39 |
-
return sample.to_numpy()
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class BaseDataModule(LightningDataModule):
|
| 43 |
-
"""Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler."""
|
| 44 |
-
|
| 45 |
-
def __init__(
|
| 46 |
-
self,
|
| 47 |
-
configuration: str | None = None,
|
| 48 |
-
sample_dir: str | None = None,
|
| 49 |
-
batch_size: int = 16,
|
| 50 |
-
num_workers: int = 0,
|
| 51 |
-
prefetch_factor: int | None = None,
|
| 52 |
-
train_period: list[str | None] = [None, None],
|
| 53 |
-
val_period: list[str | None] = [None, None],
|
| 54 |
-
):
|
| 55 |
-
"""Base Datamodule for training pvnet architecture.
|
| 56 |
-
|
| 57 |
-
Can also be used with pre-made batches if `sample_dir` is set.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
configuration: Path to ocf-data-sampler configuration file.
|
| 61 |
-
sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
|
| 62 |
-
`configuration` or '[train/val]_period'.
|
| 63 |
-
batch_size: Batch size.
|
| 64 |
-
num_workers: Number of workers to use in multiprocess batch loading.
|
| 65 |
-
prefetch_factor: Number of data will be prefetched at the end of each worker process.
|
| 66 |
-
train_period: Date range filter for train dataloader.
|
| 67 |
-
val_period: Date range filter for val dataloader.
|
| 68 |
-
|
| 69 |
-
"""
|
| 70 |
-
super().__init__()
|
| 71 |
-
|
| 72 |
-
if not ((sample_dir is not None) ^ (configuration is not None)):
|
| 73 |
-
raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")
|
| 74 |
-
|
| 75 |
-
if sample_dir is not None:
|
| 76 |
-
if any([period != [None, None] for period in [train_period, val_period]]):
|
| 77 |
-
raise ValueError("Cannot set `(train/val)_period` with presaved samples")
|
| 78 |
-
|
| 79 |
-
self.configuration = configuration
|
| 80 |
-
self.sample_dir = sample_dir
|
| 81 |
-
self.train_period = train_period
|
| 82 |
-
self.val_period = val_period
|
| 83 |
-
|
| 84 |
-
self._common_dataloader_kwargs = dict(
|
| 85 |
-
batch_size=batch_size,
|
| 86 |
-
sampler=None,
|
| 87 |
-
batch_sampler=None,
|
| 88 |
-
num_workers=num_workers,
|
| 89 |
-
collate_fn=collate_fn,
|
| 90 |
-
pin_memory=False,
|
| 91 |
-
drop_last=False,
|
| 92 |
-
timeout=0,
|
| 93 |
-
worker_init_fn=None,
|
| 94 |
-
prefetch_factor=prefetch_factor,
|
| 95 |
-
persistent_workers=False,
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
|
| 99 |
-
raise NotImplementedError
|
| 100 |
-
|
| 101 |
-
def _get_premade_samples_dataset(self, subdir) -> Dataset:
|
| 102 |
-
raise NotImplementedError
|
| 103 |
-
|
| 104 |
-
def train_dataloader(self) -> DataLoader:
|
| 105 |
-
"""Construct train dataloader"""
|
| 106 |
-
if self.sample_dir is not None:
|
| 107 |
-
dataset = self._get_premade_samples_dataset("train")
|
| 108 |
-
else:
|
| 109 |
-
dataset = self._get_streamed_samples_dataset(*self.train_period)
|
| 110 |
-
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
|
| 111 |
-
|
| 112 |
-
def val_dataloader(self) -> DataLoader:
|
| 113 |
-
"""Construct val dataloader"""
|
| 114 |
-
if self.sample_dir is not None:
|
| 115 |
-
dataset = self._get_premade_samples_dataset("val")
|
| 116 |
-
else:
|
| 117 |
-
dataset = self._get_streamed_samples_dataset(*self.val_period)
|
| 118 |
-
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/data/site_datamodule.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
""" Data module for pytorch lightning """
|
| 2 |
-
|
| 3 |
-
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
|
| 4 |
-
from ocf_data_sampler.torch_datasets.sample.site import SiteSample
|
| 5 |
-
from torch.utils.data import Dataset
|
| 6 |
-
|
| 7 |
-
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class SiteDataModule(BaseDataModule):
|
| 11 |
-
"""Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`."""
|
| 12 |
-
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
configuration: str | None = None,
|
| 16 |
-
sample_dir: str | None = None,
|
| 17 |
-
batch_size: int = 16,
|
| 18 |
-
num_workers: int = 0,
|
| 19 |
-
prefetch_factor: int | None = None,
|
| 20 |
-
train_period: list[str | None] = [None, None],
|
| 21 |
-
val_period: list[str | None] = [None, None],
|
| 22 |
-
):
|
| 23 |
-
"""Datamodule for training pvnet architecture.
|
| 24 |
-
|
| 25 |
-
Can also be used with pre-made batches if `sample_dir` is set.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
configuration: Path to configuration file.
|
| 29 |
-
sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
|
| 30 |
-
`configuration` or '[train/val]_period'.
|
| 31 |
-
batch_size: Batch size.
|
| 32 |
-
num_workers: Number of workers to use in multiprocess batch loading.
|
| 33 |
-
prefetch_factor: Number of data will be prefetched at the end of each worker process.
|
| 34 |
-
train_period: Date range filter for train dataloader.
|
| 35 |
-
val_period: Date range filter for val dataloader.
|
| 36 |
-
|
| 37 |
-
"""
|
| 38 |
-
super().__init__(
|
| 39 |
-
configuration=configuration,
|
| 40 |
-
sample_dir=sample_dir,
|
| 41 |
-
batch_size=batch_size,
|
| 42 |
-
num_workers=num_workers,
|
| 43 |
-
prefetch_factor=prefetch_factor,
|
| 44 |
-
train_period=train_period,
|
| 45 |
-
val_period=val_period,
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
|
| 49 |
-
return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
|
| 50 |
-
|
| 51 |
-
def _get_premade_samples_dataset(self, subdir) -> Dataset:
|
| 52 |
-
split_dir = f"{self.sample_dir}/{subdir}"
|
| 53 |
-
return PremadeSamplesDataset(split_dir, SiteSample)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/data/uk_regional_datamodule.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
""" Data module for pytorch lightning """
|
| 2 |
-
|
| 3 |
-
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
|
| 4 |
-
from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample
|
| 5 |
-
from torch.utils.data import Dataset
|
| 6 |
-
|
| 7 |
-
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class DataModule(BaseDataModule):
|
| 11 |
-
"""Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`."""
|
| 12 |
-
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
configuration: str | None = None,
|
| 16 |
-
sample_dir: str | None = None,
|
| 17 |
-
batch_size: int = 16,
|
| 18 |
-
num_workers: int = 0,
|
| 19 |
-
prefetch_factor: int | None = None,
|
| 20 |
-
train_period: list[str | None] = [None, None],
|
| 21 |
-
val_period: list[str | None] = [None, None],
|
| 22 |
-
):
|
| 23 |
-
"""Datamodule for training pvnet architecture.
|
| 24 |
-
|
| 25 |
-
Can also be used with pre-made batches if `sample_dir` is set.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
configuration: Path to configuration file.
|
| 29 |
-
sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
|
| 30 |
-
`configuration` or '[train/val]_period'.
|
| 31 |
-
batch_size: Batch size.
|
| 32 |
-
num_workers: Number of workers to use in multiprocess batch loading.
|
| 33 |
-
prefetch_factor: Number of data will be prefetched at the end of each worker process.
|
| 34 |
-
train_period: Date range filter for train dataloader.
|
| 35 |
-
val_period: Date range filter for val dataloader.
|
| 36 |
-
|
| 37 |
-
"""
|
| 38 |
-
super().__init__(
|
| 39 |
-
configuration=configuration,
|
| 40 |
-
sample_dir=sample_dir,
|
| 41 |
-
batch_size=batch_size,
|
| 42 |
-
num_workers=num_workers,
|
| 43 |
-
prefetch_factor=prefetch_factor,
|
| 44 |
-
train_period=train_period,
|
| 45 |
-
val_period=val_period,
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
|
| 49 |
-
return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
|
| 50 |
-
|
| 51 |
-
def _get_premade_samples_dataset(self, subdir) -> Dataset:
|
| 52 |
-
split_dir = f"{self.sample_dir}/{subdir}"
|
| 53 |
-
# Returns a dict of np arrays
|
| 54 |
-
return PremadeSamplesDataset(split_dir, UKRegionalSample)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/load_model.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
""" Load a model from its checkpoint directory """
|
| 2 |
-
import glob
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
import hydra
|
| 6 |
-
import torch
|
| 7 |
-
from pyaml_env import parse_config
|
| 8 |
-
|
| 9 |
-
from pvnet.models.ensemble import Ensemble
|
| 10 |
-
from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def get_model_from_checkpoints(
|
| 14 |
-
checkpoint_dir_paths: list[str],
|
| 15 |
-
val_best: bool = True,
|
| 16 |
-
):
|
| 17 |
-
"""Load a model from its checkpoint directory"""
|
| 18 |
-
is_ensemble = len(checkpoint_dir_paths) > 1
|
| 19 |
-
|
| 20 |
-
model_configs = []
|
| 21 |
-
models = []
|
| 22 |
-
data_configs = []
|
| 23 |
-
|
| 24 |
-
for path in checkpoint_dir_paths:
|
| 25 |
-
# Load the model
|
| 26 |
-
model_config = parse_config(f"{path}/model_config.yaml")
|
| 27 |
-
|
| 28 |
-
model = hydra.utils.instantiate(model_config)
|
| 29 |
-
|
| 30 |
-
if val_best:
|
| 31 |
-
# Only one epoch (best) saved per model
|
| 32 |
-
files = glob.glob(f"{path}/epoch*.ckpt")
|
| 33 |
-
if len(files) != 1:
|
| 34 |
-
raise ValueError(
|
| 35 |
-
f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one."
|
| 36 |
-
)
|
| 37 |
-
# TODO: Loading with weights_only=False is not recommended
|
| 38 |
-
checkpoint = torch.load(files[0], map_location="cpu", weights_only=False)
|
| 39 |
-
else:
|
| 40 |
-
checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu", weights_only=False)
|
| 41 |
-
|
| 42 |
-
model.load_state_dict(state_dict=checkpoint["state_dict"])
|
| 43 |
-
|
| 44 |
-
if isinstance(model, UMTModel):
|
| 45 |
-
model, model_config = model.convert_to_multimodal_model(model_config)
|
| 46 |
-
|
| 47 |
-
# Check for data config
|
| 48 |
-
data_config = f"{path}/data_config.yaml"
|
| 49 |
-
|
| 50 |
-
if os.path.isfile(data_config):
|
| 51 |
-
data_configs.append(data_config)
|
| 52 |
-
else:
|
| 53 |
-
data_configs.append(None)
|
| 54 |
-
|
| 55 |
-
model_configs.append(model_config)
|
| 56 |
-
models.append(model)
|
| 57 |
-
|
| 58 |
-
if is_ensemble:
|
| 59 |
-
model_config = {
|
| 60 |
-
"_target_": "pvnet.models.ensemble.Ensemble",
|
| 61 |
-
"model_list": model_configs,
|
| 62 |
-
}
|
| 63 |
-
model = Ensemble(model_list=models)
|
| 64 |
-
data_config = data_configs[0]
|
| 65 |
-
|
| 66 |
-
else:
|
| 67 |
-
model_config = model_configs[0]
|
| 68 |
-
model = models[0]
|
| 69 |
-
data_config = data_configs[0]
|
| 70 |
-
|
| 71 |
-
return model, model_config, data_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Models for PVNet"""
|
|
|
|
|
|
pvnet/models/base_model.py
DELETED
|
@@ -1,973 +0,0 @@
|
|
| 1 |
-
"""Base model for all PVNet submodels"""
|
| 2 |
-
import copy
|
| 3 |
-
import logging
|
| 4 |
-
import os
|
| 5 |
-
import tempfile
|
| 6 |
-
import time
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from typing import Dict, Optional, Union
|
| 9 |
-
|
| 10 |
-
import hydra
|
| 11 |
-
import lightning.pytorch as pl
|
| 12 |
-
import matplotlib.pyplot as plt
|
| 13 |
-
import pandas as pd
|
| 14 |
-
import pkg_resources
|
| 15 |
-
import torch
|
| 16 |
-
import torch.nn.functional as F
|
| 17 |
-
import wandb
|
| 18 |
-
import yaml
|
| 19 |
-
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
|
| 20 |
-
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
|
| 21 |
-
from huggingface_hub.file_download import hf_hub_download
|
| 22 |
-
from huggingface_hub.hf_api import HfApi
|
| 23 |
-
from ocf_data_sampler.torch_datasets.sample.base import copy_batch_to_device
|
| 24 |
-
from torchvision.transforms.functional import center_crop
|
| 25 |
-
|
| 26 |
-
from pvnet.models.utils import (
|
| 27 |
-
BatchAccumulator,
|
| 28 |
-
MetricAccumulator,
|
| 29 |
-
PredAccumulator,
|
| 30 |
-
)
|
| 31 |
-
from pvnet.optimizers import AbstractOptimizer
|
| 32 |
-
from pvnet.utils import plot_batch_forecasts
|
| 33 |
-
|
| 34 |
-
DATA_CONFIG_NAME = "data_config.yaml"
|
| 35 |
-
MODEL_CONFIG_NAME = "model_config.yaml"
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
logger = logging.getLogger(__name__)
|
| 39 |
-
|
| 40 |
-
activities = [torch.profiler.ProfilerActivity.CPU]
|
| 41 |
-
if torch.cuda.is_available():
|
| 42 |
-
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):
|
| 46 |
-
"""Resave the data config and replace the filepaths with a placeholder.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
input_path: Path to input configuration file
|
| 50 |
-
output_path: Location to save the output configuration file
|
| 51 |
-
placeholder: String placeholder for data sources
|
| 52 |
-
"""
|
| 53 |
-
with open(input_path) as cfg:
|
| 54 |
-
config = yaml.load(cfg, Loader=yaml.FullLoader)
|
| 55 |
-
|
| 56 |
-
config["general"]["description"] = "Config for training the saved PVNet model"
|
| 57 |
-
config["general"]["name"] = "PVNet current"
|
| 58 |
-
|
| 59 |
-
for source in ["gsp", "satellite", "hrvsatellite"]:
|
| 60 |
-
if source in config["input_data"]:
|
| 61 |
-
# If not empty - i.e. if used
|
| 62 |
-
if config["input_data"][source]["zarr_path"] != "":
|
| 63 |
-
config["input_data"][source]["zarr_path"] = f"{placeholder}.zarr"
|
| 64 |
-
|
| 65 |
-
if "nwp" in config["input_data"]:
|
| 66 |
-
for source in config["input_data"]["nwp"]:
|
| 67 |
-
if config["input_data"]["nwp"][source]["zarr_path"] != "":
|
| 68 |
-
config["input_data"]["nwp"][source]["zarr_path"] = f"{placeholder}.zarr"
|
| 69 |
-
|
| 70 |
-
if "pv" in config["input_data"]:
|
| 71 |
-
for d in config["input_data"]["pv"]["pv_files_groups"]:
|
| 72 |
-
d["pv_filename"] = f"{placeholder}.netcdf"
|
| 73 |
-
d["pv_metadata_filename"] = f"{placeholder}.csv"
|
| 74 |
-
|
| 75 |
-
if "sensor" in config["input_data"]:
|
| 76 |
-
# If not empty - i.e. if used
|
| 77 |
-
if config["input_data"][source][f"{source}_filename"] != "":
|
| 78 |
-
config["input_data"][source][f"{source}_filename"] = f"{placeholder}.nc"
|
| 79 |
-
|
| 80 |
-
with open(output_path, "w") as outfile:
|
| 81 |
-
yaml.dump(config, outfile, default_flow_style=False)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def minimize_data_config(input_path, output_path, model):
|
| 85 |
-
"""Strip out parts of the data config which aren't used by the model
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
input_path: Path to input configuration file
|
| 89 |
-
output_path: Location to save the output configuration file
|
| 90 |
-
model: The PVNet model object
|
| 91 |
-
"""
|
| 92 |
-
with open(input_path) as cfg:
|
| 93 |
-
config = yaml.load(cfg, Loader=yaml.FullLoader)
|
| 94 |
-
|
| 95 |
-
if "nwp" in config["input_data"]:
|
| 96 |
-
if not model.include_nwp:
|
| 97 |
-
del config["input_data"]["nwp"]
|
| 98 |
-
else:
|
| 99 |
-
for nwp_source in list(config["input_data"]["nwp"].keys()):
|
| 100 |
-
nwp_config = config["input_data"]["nwp"][nwp_source]
|
| 101 |
-
|
| 102 |
-
if nwp_source not in model.nwp_encoders_dict:
|
| 103 |
-
# If not used, delete this source from the config
|
| 104 |
-
del config["input_data"]["nwp"][nwp_source]
|
| 105 |
-
else:
|
| 106 |
-
# Replace the image size
|
| 107 |
-
nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels
|
| 108 |
-
nwp_config["image_size_pixels_height"] = nwp_pixel_size
|
| 109 |
-
nwp_config["image_size_pixels_width"] = nwp_pixel_size
|
| 110 |
-
|
| 111 |
-
# Replace the interval_end_minutes minutes
|
| 112 |
-
nwp_config["interval_end_minutes"] = (
|
| 113 |
-
nwp_config["interval_start_minutes"] +
|
| 114 |
-
(model.nwp_encoders_dict[nwp_source].sequence_length - 1)
|
| 115 |
-
* nwp_config["time_resolution_minutes"]
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
if "satellite" in config["input_data"]:
|
| 119 |
-
if not model.include_sat:
|
| 120 |
-
del config["input_data"]["satellite"]
|
| 121 |
-
else:
|
| 122 |
-
sat_config = config["input_data"]["satellite"]
|
| 123 |
-
|
| 124 |
-
# Replace the image size
|
| 125 |
-
sat_pixel_size = model.sat_encoder.image_size_pixels
|
| 126 |
-
sat_config["image_size_pixels_height"] = sat_pixel_size
|
| 127 |
-
sat_config["image_size_pixels_width"] = sat_pixel_size
|
| 128 |
-
|
| 129 |
-
# Replace the interval_end_minutes minutes
|
| 130 |
-
sat_config["interval_end_minutes"] = (
|
| 131 |
-
sat_config["interval_start_minutes"] +
|
| 132 |
-
(model.sat_encoder.sequence_length - 1)
|
| 133 |
-
* sat_config["time_resolution_minutes"]
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
if "pv" in config["input_data"]:
|
| 137 |
-
if not model.include_pv:
|
| 138 |
-
del config["input_data"]["pv"]
|
| 139 |
-
|
| 140 |
-
if "gsp" in config["input_data"]:
|
| 141 |
-
gsp_config = config["input_data"]["gsp"]
|
| 142 |
-
|
| 143 |
-
# Replace the forecast minutes
|
| 144 |
-
gsp_config["interval_end_minutes"] = model.forecast_minutes
|
| 145 |
-
|
| 146 |
-
if "solar_position" in config["input_data"]:
|
| 147 |
-
solar_config = config["input_data"]["solar_position"]
|
| 148 |
-
solar_config["interval_end_minutes"] = model.forecast_minutes
|
| 149 |
-
|
| 150 |
-
with open(output_path, "w") as outfile:
|
| 151 |
-
yaml.dump(config, outfile, default_flow_style=False)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def download_hf_hub_with_retries(
|
| 155 |
-
repo_id,
|
| 156 |
-
filename,
|
| 157 |
-
revision,
|
| 158 |
-
cache_dir,
|
| 159 |
-
force_download,
|
| 160 |
-
proxies,
|
| 161 |
-
resume_download,
|
| 162 |
-
token,
|
| 163 |
-
local_files_only,
|
| 164 |
-
max_retries=5,
|
| 165 |
-
wait_time=10,
|
| 166 |
-
):
|
| 167 |
-
"""
|
| 168 |
-
Tries to download a file from HuggingFace up to max_retries times.
|
| 169 |
-
|
| 170 |
-
Args:
|
| 171 |
-
repo_id (str): HuggingFace repo ID
|
| 172 |
-
filename (str): Name of the file to download
|
| 173 |
-
revision (str): Specific model revision
|
| 174 |
-
cache_dir (str): Cache directory
|
| 175 |
-
force_download (bool): Whether to force a new download
|
| 176 |
-
proxies (dict): Proxy settings
|
| 177 |
-
resume_download (bool): Resume interrupted downloads
|
| 178 |
-
token (str): HuggingFace auth token
|
| 179 |
-
local_files_only (bool): Use local files only
|
| 180 |
-
max_retries (int): Maximum number of retry attempts
|
| 181 |
-
wait_time (int): Wait time (in seconds) before retrying
|
| 182 |
-
|
| 183 |
-
Returns:
|
| 184 |
-
str: The local file path of the downloaded file
|
| 185 |
-
"""
|
| 186 |
-
for attempt in range(1, max_retries + 1):
|
| 187 |
-
try:
|
| 188 |
-
return hf_hub_download(
|
| 189 |
-
repo_id=repo_id,
|
| 190 |
-
filename=filename,
|
| 191 |
-
revision=revision,
|
| 192 |
-
cache_dir=cache_dir,
|
| 193 |
-
force_download=force_download,
|
| 194 |
-
proxies=proxies,
|
| 195 |
-
resume_download=resume_download,
|
| 196 |
-
token=token,
|
| 197 |
-
local_files_only=local_files_only,
|
| 198 |
-
)
|
| 199 |
-
except Exception as e:
|
| 200 |
-
if attempt == max_retries:
|
| 201 |
-
raise Exception(
|
| 202 |
-
f"Failed to download {filename} from {repo_id} after {max_retries} attempts."
|
| 203 |
-
) from e
|
| 204 |
-
logging.warning(
|
| 205 |
-
(
|
| 206 |
-
f"Attempt {attempt}/{max_retries} failed to download {filename} "
|
| 207 |
-
f"from {repo_id}. Retrying in {wait_time} seconds..."
|
| 208 |
-
)
|
| 209 |
-
)
|
| 210 |
-
time.sleep(wait_time)
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
class PVNetModelHubMixin(PyTorchModelHubMixin):
|
| 214 |
-
"""
|
| 215 |
-
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
|
| 216 |
-
"""
|
| 217 |
-
|
| 218 |
-
@classmethod
|
| 219 |
-
def from_pretrained(
|
| 220 |
-
cls,
|
| 221 |
-
*,
|
| 222 |
-
model_id: str,
|
| 223 |
-
revision: str,
|
| 224 |
-
cache_dir: Optional[Union[str, Path]] = None,
|
| 225 |
-
force_download: bool = False,
|
| 226 |
-
proxies: Optional[Dict] = None,
|
| 227 |
-
resume_download: Optional[bool] = None,
|
| 228 |
-
local_files_only: bool = False,
|
| 229 |
-
token: Union[str, bool, None] = None,
|
| 230 |
-
map_location: str = "cpu",
|
| 231 |
-
strict: bool = False,
|
| 232 |
-
):
|
| 233 |
-
"""Load Pytorch pretrained weights and return the loaded model."""
|
| 234 |
-
|
| 235 |
-
if os.path.isdir(model_id):
|
| 236 |
-
print("Loading weights from local directory")
|
| 237 |
-
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
| 238 |
-
config_file = os.path.join(model_id, MODEL_CONFIG_NAME)
|
| 239 |
-
else:
|
| 240 |
-
# load model file
|
| 241 |
-
model_file = download_hf_hub_with_retries(
|
| 242 |
-
repo_id=model_id,
|
| 243 |
-
filename=PYTORCH_WEIGHTS_NAME,
|
| 244 |
-
revision=revision,
|
| 245 |
-
cache_dir=cache_dir,
|
| 246 |
-
force_download=force_download,
|
| 247 |
-
proxies=proxies,
|
| 248 |
-
resume_download=resume_download,
|
| 249 |
-
token=token,
|
| 250 |
-
local_files_only=local_files_only,
|
| 251 |
-
max_retries=5,
|
| 252 |
-
wait_time=10,
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
# load config file
|
| 256 |
-
config_file = download_hf_hub_with_retries(
|
| 257 |
-
repo_id=model_id,
|
| 258 |
-
filename=MODEL_CONFIG_NAME,
|
| 259 |
-
revision=revision,
|
| 260 |
-
cache_dir=cache_dir,
|
| 261 |
-
force_download=force_download,
|
| 262 |
-
proxies=proxies,
|
| 263 |
-
resume_download=resume_download,
|
| 264 |
-
token=token,
|
| 265 |
-
local_files_only=local_files_only,
|
| 266 |
-
max_retries=5,
|
| 267 |
-
wait_time=10,
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
with open(config_file, "r") as f:
|
| 271 |
-
config = yaml.safe_load(f)
|
| 272 |
-
|
| 273 |
-
model = hydra.utils.instantiate(config)
|
| 274 |
-
|
| 275 |
-
state_dict = torch.load(model_file, map_location=torch.device(map_location))
|
| 276 |
-
model.load_state_dict(state_dict, strict=strict) # type: ignore
|
| 277 |
-
model.eval() # type: ignore
|
| 278 |
-
|
| 279 |
-
return model
|
| 280 |
-
|
| 281 |
-
@classmethod
|
| 282 |
-
def get_data_config(
|
| 283 |
-
cls,
|
| 284 |
-
model_id: str,
|
| 285 |
-
revision: str,
|
| 286 |
-
cache_dir: Optional[Union[str, Path]] = None,
|
| 287 |
-
force_download: bool = False,
|
| 288 |
-
proxies: Optional[Dict] = None,
|
| 289 |
-
resume_download: bool = False,
|
| 290 |
-
local_files_only: bool = False,
|
| 291 |
-
token: Optional[Union[str, bool]] = None,
|
| 292 |
-
):
|
| 293 |
-
"""Load data config file."""
|
| 294 |
-
if os.path.isdir(model_id):
|
| 295 |
-
print("Loading data config from local directory")
|
| 296 |
-
data_config_file = os.path.join(model_id, DATA_CONFIG_NAME)
|
| 297 |
-
else:
|
| 298 |
-
data_config_file = download_hf_hub_with_retries(
|
| 299 |
-
repo_id=model_id,
|
| 300 |
-
filename=DATA_CONFIG_NAME,
|
| 301 |
-
revision=revision,
|
| 302 |
-
cache_dir=cache_dir,
|
| 303 |
-
force_download=force_download,
|
| 304 |
-
proxies=proxies,
|
| 305 |
-
resume_download=resume_download,
|
| 306 |
-
token=token,
|
| 307 |
-
local_files_only=local_files_only,
|
| 308 |
-
max_retries=5,
|
| 309 |
-
wait_time=10,
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
return data_config_file
|
| 313 |
-
|
| 314 |
-
def _save_pretrained(self, save_directory: Path) -> None:
|
| 315 |
-
"""Save weights from a Pytorch model to a local directory."""
|
| 316 |
-
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
|
| 317 |
-
torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
|
| 318 |
-
|
| 319 |
-
def save_pretrained(
|
| 320 |
-
self,
|
| 321 |
-
save_directory: Union[str, Path],
|
| 322 |
-
config: dict,
|
| 323 |
-
data_config: Optional[Union[str, Path]],
|
| 324 |
-
repo_id: Optional[str] = None,
|
| 325 |
-
push_to_hub: bool = False,
|
| 326 |
-
wandb_repo: Optional[str] = None,
|
| 327 |
-
wandb_ids: Optional[Union[list[str], str]] = None,
|
| 328 |
-
card_template_path: Optional[Path] = None,
|
| 329 |
-
**kwargs,
|
| 330 |
-
) -> Optional[str]:
|
| 331 |
-
"""
|
| 332 |
-
Save weights in local directory.
|
| 333 |
-
|
| 334 |
-
Args:
|
| 335 |
-
save_directory (`str` or `Path`):
|
| 336 |
-
Path to directory in which the model weights and configuration will be saved.
|
| 337 |
-
config (`dict`):
|
| 338 |
-
Model configuration specified as a key/value dictionary.
|
| 339 |
-
data_config (`str` or `Path`):
|
| 340 |
-
The path to the data config.
|
| 341 |
-
repo_id (`str`, *optional*):
|
| 342 |
-
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to
|
| 343 |
-
the folder name if not provided.
|
| 344 |
-
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 345 |
-
Whether or not to push your model to the HuggingFace Hub after saving it.
|
| 346 |
-
wandb_repo: Identifier of the repo on wandb.
|
| 347 |
-
wandb_ids: Identifier(s) of the model on wandb.
|
| 348 |
-
card_template_path: Path to the HuggingFace model card template. Defaults to card in
|
| 349 |
-
PVNet library if set to None.
|
| 350 |
-
kwargs:
|
| 351 |
-
Additional key word arguments passed along to the
|
| 352 |
-
[`~ModelHubMixin._from_pretrained`] method.
|
| 353 |
-
"""
|
| 354 |
-
|
| 355 |
-
save_directory = Path(save_directory)
|
| 356 |
-
save_directory.mkdir(parents=True, exist_ok=True)
|
| 357 |
-
|
| 358 |
-
# saving model weights/files
|
| 359 |
-
self._save_pretrained(save_directory)
|
| 360 |
-
|
| 361 |
-
# saving model and data config
|
| 362 |
-
if isinstance(config, dict):
|
| 363 |
-
with open(save_directory / MODEL_CONFIG_NAME, "w") as f:
|
| 364 |
-
yaml.dump(config, f, sort_keys=False, default_flow_style=False)
|
| 365 |
-
|
| 366 |
-
# Save cleaned configuration file
|
| 367 |
-
if data_config is not None:
|
| 368 |
-
new_data_config_path = save_directory / DATA_CONFIG_NAME
|
| 369 |
-
|
| 370 |
-
# Replace the input filenames with place holders
|
| 371 |
-
make_clean_data_config(data_config, new_data_config_path)
|
| 372 |
-
|
| 373 |
-
# Taylor the data config to the model being saved
|
| 374 |
-
minimize_data_config(new_data_config_path, new_data_config_path, self)
|
| 375 |
-
|
| 376 |
-
card = self.create_hugging_face_model_card(
|
| 377 |
-
repo_id, wandb_repo, wandb_ids, card_template_path
|
| 378 |
-
)
|
| 379 |
-
|
| 380 |
-
(save_directory / "README.md").write_text(str(card))
|
| 381 |
-
|
| 382 |
-
if push_to_hub:
|
| 383 |
-
api = HfApi()
|
| 384 |
-
|
| 385 |
-
api.upload_folder(
|
| 386 |
-
repo_id=repo_id,
|
| 387 |
-
repo_type="model",
|
| 388 |
-
folder_path=save_directory,
|
| 389 |
-
)
|
| 390 |
-
|
| 391 |
-
# Print the most recent commit hash
|
| 392 |
-
c = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0]
|
| 393 |
-
|
| 394 |
-
message = (
|
| 395 |
-
f"The latest commit is now: \n"
|
| 396 |
-
f" date: {c.created_at} \n"
|
| 397 |
-
f" commit hash: {c.commit_id}\n"
|
| 398 |
-
f" by: {c.authors}\n"
|
| 399 |
-
f" title: {c.title}\n"
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
-
print(message)
|
| 403 |
-
|
| 404 |
-
return None
|
| 405 |
-
|
| 406 |
-
@staticmethod
|
| 407 |
-
def create_hugging_face_model_card(
|
| 408 |
-
repo_id: Optional[str] = None,
|
| 409 |
-
wandb_repo: Optional[str] = None,
|
| 410 |
-
wandb_ids: Optional[Union[list[str], str]] = None,
|
| 411 |
-
card_template_path: Optional[Path] = None,
|
| 412 |
-
) -> ModelCard:
|
| 413 |
-
"""
|
| 414 |
-
Creates Hugging Face model card
|
| 415 |
-
|
| 416 |
-
Args:
|
| 417 |
-
repo_id (`str`, *optional*):
|
| 418 |
-
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to
|
| 419 |
-
the folder name if not provided.
|
| 420 |
-
wandb_repo: Identifier of the repo on wandb.
|
| 421 |
-
wandb_ids: Identifier(s) of the model on wandb.
|
| 422 |
-
card_template_path: Path to the HuggingFace model card template. Defaults to card in
|
| 423 |
-
PVNet library if set to None.
|
| 424 |
-
|
| 425 |
-
Returns:
|
| 426 |
-
card: ModelCard - Hugging Face model card object
|
| 427 |
-
"""
|
| 428 |
-
|
| 429 |
-
# Get appropriate model card
|
| 430 |
-
model_name = repo_id.split("/")[1]
|
| 431 |
-
if model_name == "windnet_india":
|
| 432 |
-
model_card = "wind_india_model_card_template.md"
|
| 433 |
-
elif model_name == "pvnet_india":
|
| 434 |
-
model_card = "pv_india_model_card_template.md"
|
| 435 |
-
else:
|
| 436 |
-
model_card = "pv_uk_regional_model_card_template.md"
|
| 437 |
-
|
| 438 |
-
# Creating and saving model card.
|
| 439 |
-
card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
|
| 440 |
-
if card_template_path is None:
|
| 441 |
-
card_template_path = (
|
| 442 |
-
f"{os.path.dirname(os.path.abspath(__file__))}/model_cards/{model_card}"
|
| 443 |
-
)
|
| 444 |
-
|
| 445 |
-
if isinstance(wandb_ids, str):
|
| 446 |
-
wandb_ids = [wandb_ids]
|
| 447 |
-
|
| 448 |
-
wandb_links = ""
|
| 449 |
-
for wandb_id in wandb_ids:
|
| 450 |
-
link = f"https://wandb.ai/{wandb_repo}/runs/{wandb_id}"
|
| 451 |
-
wandb_links += f" - [{link}]({link})\n"
|
| 452 |
-
|
| 453 |
-
# Find package versions for OCF packages
|
| 454 |
-
packages_to_display = ["pvnet", "ocf-data-sampler"]
|
| 455 |
-
packages_and_versions = {
|
| 456 |
-
package_name: pkg_resources.get_distribution(package_name).version
|
| 457 |
-
for package_name in packages_to_display
|
| 458 |
-
}
|
| 459 |
-
|
| 460 |
-
package_versions_markdown = ""
|
| 461 |
-
for package, version in packages_and_versions.items():
|
| 462 |
-
package_versions_markdown += f" - {package}=={version}\n"
|
| 463 |
-
|
| 464 |
-
return ModelCard.from_template(
|
| 465 |
-
card_data,
|
| 466 |
-
template_path=card_template_path,
|
| 467 |
-
wandb_links=wandb_links,
|
| 468 |
-
package_versions=package_versions_markdown
|
| 469 |
-
)
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
class BaseModel(pl.LightningModule, PVNetModelHubMixin):
|
| 473 |
-
"""Abstract base class for PVNet submodels"""
|
| 474 |
-
|
| 475 |
-
def __init__(
|
| 476 |
-
self,
|
| 477 |
-
history_minutes: int,
|
| 478 |
-
forecast_minutes: int,
|
| 479 |
-
optimizer: AbstractOptimizer,
|
| 480 |
-
output_quantiles: Optional[list[float]] = None,
|
| 481 |
-
target_key: str = "gsp",
|
| 482 |
-
interval_minutes: int = 30,
|
| 483 |
-
timestep_intervals_to_plot: Optional[list[int]] = None,
|
| 484 |
-
forecast_minutes_ignore: Optional[int] = 0,
|
| 485 |
-
save_validation_results_csv: Optional[bool] = False,
|
| 486 |
-
):
|
| 487 |
-
"""Abtstract base class for PVNet submodels.
|
| 488 |
-
|
| 489 |
-
Args:
|
| 490 |
-
history_minutes (int): Length of the GSP history period in minutes
|
| 491 |
-
forecast_minutes (int): Length of the GSP forecast period in minutes
|
| 492 |
-
optimizer (AbstractOptimizer): Optimizer
|
| 493 |
-
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
|
| 494 |
-
None the output is a single value.
|
| 495 |
-
target_key: The key of the target variable in the batch
|
| 496 |
-
interval_minutes: The interval in minutes between each timestep in the data
|
| 497 |
-
timestep_intervals_to_plot: Intervals, in timesteps, to plot during training
|
| 498 |
-
forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses.
|
| 499 |
-
For example if set to 60, the model doesnt predict the first 60 minutes
|
| 500 |
-
save_validation_results_csv: whether to save full csv outputs from validation results.
|
| 501 |
-
"""
|
| 502 |
-
super().__init__()
|
| 503 |
-
|
| 504 |
-
self._optimizer = optimizer
|
| 505 |
-
self._target_key = target_key
|
| 506 |
-
if timestep_intervals_to_plot is not None:
|
| 507 |
-
for interval in timestep_intervals_to_plot:
|
| 508 |
-
assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(
|
| 509 |
-
f"timestep_intervals_to_plot must be a list of tuples or lists of length 2, "
|
| 510 |
-
f"but got {timestep_intervals_to_plot=}"
|
| 511 |
-
)
|
| 512 |
-
self.time_step_intervals_to_plot = timestep_intervals_to_plot
|
| 513 |
-
|
| 514 |
-
# Model must have lr to allow tuning
|
| 515 |
-
# This setting is only used when lr is tuned with callback
|
| 516 |
-
self.lr = None
|
| 517 |
-
|
| 518 |
-
self.history_minutes = history_minutes
|
| 519 |
-
self.forecast_minutes = forecast_minutes
|
| 520 |
-
self.output_quantiles = output_quantiles
|
| 521 |
-
self.interval_minutes = interval_minutes
|
| 522 |
-
self.forecast_minutes_ignore = forecast_minutes_ignore
|
| 523 |
-
|
| 524 |
-
# Number of timestemps for 30 minutely data
|
| 525 |
-
self.history_len = history_minutes // interval_minutes
|
| 526 |
-
self.forecast_len = (forecast_minutes - forecast_minutes_ignore) // interval_minutes
|
| 527 |
-
self.forecast_len_ignore = forecast_minutes_ignore // interval_minutes
|
| 528 |
-
|
| 529 |
-
self._accumulated_metrics = MetricAccumulator()
|
| 530 |
-
self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key)
|
| 531 |
-
self._accumulated_y_hat = PredAccumulator()
|
| 532 |
-
self._horizon_maes = MetricAccumulator()
|
| 533 |
-
|
| 534 |
-
# Store whether the model should use quantile regression or simply predict the mean
|
| 535 |
-
self.use_quantile_regression = self.output_quantiles is not None
|
| 536 |
-
|
| 537 |
-
# Store the number of ouput features that the model should predict for
|
| 538 |
-
if self.use_quantile_regression:
|
| 539 |
-
self.num_output_features = self.forecast_len * len(self.output_quantiles)
|
| 540 |
-
else:
|
| 541 |
-
self.num_output_features = self.forecast_len
|
| 542 |
-
|
| 543 |
-
# save all validation results to array, so we can save these to weights n biases
|
| 544 |
-
self.validation_epoch_results = []
|
| 545 |
-
self.save_validation_results_csv = save_validation_results_csv
|
| 546 |
-
|
| 547 |
-
def _adapt_batch(self, batch):
|
| 548 |
-
"""Slice batches into appropriate shapes for model.
|
| 549 |
-
|
| 550 |
-
Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
|
| 551 |
-
We make some specific assumptions about the original batch and the derived sliced batch:
|
| 552 |
-
- We are only limiting the future projections. I.e. we are never shrinking the batch from
|
| 553 |
-
the left hand side of the time axis, only slicing it from the right
|
| 554 |
-
- We are only shrinking the spatial crop of the satellite and NWP data
|
| 555 |
-
|
| 556 |
-
"""
|
| 557 |
-
# Create a copy of the batch to avoid modifying the original
|
| 558 |
-
new_batch = {key: copy.deepcopy(value) for key, value in batch.items()}
|
| 559 |
-
|
| 560 |
-
if "gsp" in new_batch.keys():
|
| 561 |
-
# Slice off the end of the GSP data
|
| 562 |
-
gsp_len = self.forecast_len + self.history_len + 1
|
| 563 |
-
new_batch["gsp"] = new_batch["gsp"][:, :gsp_len]
|
| 564 |
-
new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len]
|
| 565 |
-
|
| 566 |
-
if self.include_sat:
|
| 567 |
-
# Slice off the end of the satellite data and spatially crop
|
| 568 |
-
# Shape: batch_size, seq_length, channel, height, width
|
| 569 |
-
new_batch["satellite_actual"] = center_crop(
|
| 570 |
-
new_batch["satellite_actual"][:, : self.sat_sequence_len],
|
| 571 |
-
output_size=self.sat_encoder.image_size_pixels,
|
| 572 |
-
)
|
| 573 |
-
|
| 574 |
-
if self.include_nwp:
|
| 575 |
-
# Slice off the end of the NWP data and spatially crop
|
| 576 |
-
for nwp_source in self.nwp_encoders_dict:
|
| 577 |
-
# shape: batch_size, seq_len, n_chans, height, width
|
| 578 |
-
new_batch["nwp"][nwp_source]["nwp"] = center_crop(
|
| 579 |
-
new_batch["nwp"][nwp_source]["nwp"],
|
| 580 |
-
output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
|
| 581 |
-
)[:, : self.nwp_encoders_dict[nwp_source].sequence_length]
|
| 582 |
-
|
| 583 |
-
if self.include_sun:
|
| 584 |
-
sun_len = self.forecast_len + self.history_len + 1
|
| 585 |
-
# Slice off end of solar coords
|
| 586 |
-
for s in ["solar_azimuth", "solar_elevation"]:
|
| 587 |
-
if s in new_batch.keys():
|
| 588 |
-
new_batch[s] = new_batch[s][:, :sun_len]
|
| 589 |
-
|
| 590 |
-
return new_batch
|
| 591 |
-
|
| 592 |
-
def transfer_batch_to_device(self, batch, device, dataloader_idx):
|
| 593 |
-
"""Method to move custom batches to a given device"""
|
| 594 |
-
return copy_batch_to_device(batch, device)
|
| 595 |
-
|
| 596 |
-
def _quantiles_to_prediction(self, y_quantiles):
|
| 597 |
-
"""
|
| 598 |
-
Convert network prediction into a point prediction.
|
| 599 |
-
|
| 600 |
-
Note:
|
| 601 |
-
Implementation copied from:
|
| 602 |
-
https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting
|
| 603 |
-
/metrics/quantile.html#QuantileLoss.loss
|
| 604 |
-
|
| 605 |
-
Args:
|
| 606 |
-
y_quantiles: Quantile prediction of network
|
| 607 |
-
|
| 608 |
-
Returns:
|
| 609 |
-
torch.Tensor: Point prediction
|
| 610 |
-
"""
|
| 611 |
-
# y_quantiles Shape: batch_size, seq_length, num_quantiles
|
| 612 |
-
idx = self.output_quantiles.index(0.5)
|
| 613 |
-
y_median = y_quantiles[..., idx]
|
| 614 |
-
return y_median
|
| 615 |
-
|
| 616 |
-
def _calculate_quantile_loss(self, y_quantiles, y):
|
| 617 |
-
"""Calculate quantile loss.
|
| 618 |
-
|
| 619 |
-
Note:
|
| 620 |
-
Implementation copied from:
|
| 621 |
-
https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting
|
| 622 |
-
/metrics/quantile.html#QuantileLoss.loss
|
| 623 |
-
|
| 624 |
-
Args:
|
| 625 |
-
y_quantiles: Quantile prediction of network
|
| 626 |
-
y: Target values
|
| 627 |
-
|
| 628 |
-
Returns:
|
| 629 |
-
Quantile loss
|
| 630 |
-
"""
|
| 631 |
-
# calculate quantile loss
|
| 632 |
-
losses = []
|
| 633 |
-
for i, q in enumerate(self.output_quantiles):
|
| 634 |
-
errors = y - y_quantiles[..., i]
|
| 635 |
-
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
|
| 636 |
-
losses = 2 * torch.cat(losses, dim=2)
|
| 637 |
-
|
| 638 |
-
return losses.mean()
|
| 639 |
-
|
| 640 |
-
def _calculate_common_losses(self, y, y_hat):
|
| 641 |
-
"""Calculate losses common to train, and val"""
|
| 642 |
-
|
| 643 |
-
losses = {}
|
| 644 |
-
|
| 645 |
-
if self.use_quantile_regression:
|
| 646 |
-
losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
|
| 647 |
-
y_hat = self._quantiles_to_prediction(y_hat)
|
| 648 |
-
|
| 649 |
-
# calculate mse, mae
|
| 650 |
-
mse_loss = F.mse_loss(y_hat, y)
|
| 651 |
-
mae_loss = F.l1_loss(y_hat, y)
|
| 652 |
-
|
| 653 |
-
# TODO: Compute correlation coef using np.corrcoef(tensor with
|
| 654 |
-
# shape (2, num_timesteps))[0, 1] on each example, and taking
|
| 655 |
-
# the mean across the batch?
|
| 656 |
-
losses.update(
|
| 657 |
-
{
|
| 658 |
-
"MSE": mse_loss,
|
| 659 |
-
"MAE": mae_loss,
|
| 660 |
-
}
|
| 661 |
-
)
|
| 662 |
-
|
| 663 |
-
return losses
|
| 664 |
-
|
| 665 |
-
def _step_mae_and_mse(self, y, y_hat, dict_key_root):
|
| 666 |
-
"""Calculate the MSE and MAE at each forecast step"""
|
| 667 |
-
losses = {}
|
| 668 |
-
|
| 669 |
-
mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
|
| 670 |
-
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
|
| 671 |
-
|
| 672 |
-
losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
|
| 673 |
-
losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})
|
| 674 |
-
|
| 675 |
-
return losses
|
| 676 |
-
|
| 677 |
-
def _calculate_val_losses(self, y, y_hat):
|
| 678 |
-
"""Calculate additional validation losses"""
|
| 679 |
-
|
| 680 |
-
losses = {}
|
| 681 |
-
|
| 682 |
-
if self.use_quantile_regression:
|
| 683 |
-
# Add fraction below each quantile for calibration
|
| 684 |
-
for i, quantile in enumerate(self.output_quantiles):
|
| 685 |
-
below_quant = y <= y_hat[..., i]
|
| 686 |
-
# Mask values small values, which are dominated by night
|
| 687 |
-
mask = y >= 0.01
|
| 688 |
-
losses[f"fraction_below_{quantile}_quantile"] = (below_quant[mask]).float().mean()
|
| 689 |
-
|
| 690 |
-
# Take median value for remaining metric calculations
|
| 691 |
-
y_hat = self._quantiles_to_prediction(y_hat)
|
| 692 |
-
|
| 693 |
-
# Log the loss at each time horizon
|
| 694 |
-
losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))
|
| 695 |
-
|
| 696 |
-
# Log the persistance losses
|
| 697 |
-
y_persist = y[:, -1].unsqueeze(1).expand(-1, self.forecast_len)
|
| 698 |
-
losses["MAE_persistence/val"] = F.l1_loss(y_persist, y)
|
| 699 |
-
losses["MSE_persistence/val"] = F.mse_loss(y_persist, y)
|
| 700 |
-
|
| 701 |
-
# Log persistance loss at each time horizon
|
| 702 |
-
losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence"))
|
| 703 |
-
return losses
|
| 704 |
-
|
| 705 |
-
def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
|
| 706 |
-
"""Internal function to accumulate training batches and log results.
|
| 707 |
-
|
| 708 |
-
This is used when accummulating grad batches. Should make the variability in logged training
|
| 709 |
-
step metrics indpendent on whether we accumulate N batches of size B or just use a larger
|
| 710 |
-
batch size of N*B with no accumulaion.
|
| 711 |
-
"""
|
| 712 |
-
|
| 713 |
-
losses = {k: v.detach().cpu() for k, v in losses.items()}
|
| 714 |
-
y_hat = y_hat.detach().cpu()
|
| 715 |
-
|
| 716 |
-
self._accumulated_metrics.append(losses)
|
| 717 |
-
self._accumulated_batches.append(batch)
|
| 718 |
-
self._accumulated_y_hat.append(y_hat)
|
| 719 |
-
|
| 720 |
-
if not self.trainer.fit_loop._should_accumulate():
|
| 721 |
-
losses = self._accumulated_metrics.flush()
|
| 722 |
-
batch = self._accumulated_batches.flush()
|
| 723 |
-
y_hat = self._accumulated_y_hat.flush()
|
| 724 |
-
|
| 725 |
-
self.log_dict(
|
| 726 |
-
losses,
|
| 727 |
-
on_step=True,
|
| 728 |
-
on_epoch=True,
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
# Number of accumulated grad batches
|
| 732 |
-
grad_batch_num = (batch_idx + 1) / self.trainer.accumulate_grad_batches
|
| 733 |
-
|
| 734 |
-
# We only create the figure every 8 log steps
|
| 735 |
-
# This was reduced as it was creating figures too often
|
| 736 |
-
if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
|
| 737 |
-
fig = plot_batch_forecasts(
|
| 738 |
-
batch,
|
| 739 |
-
y_hat,
|
| 740 |
-
batch_idx,
|
| 741 |
-
quantiles=self.output_quantiles,
|
| 742 |
-
key_to_plot=self._target_key,
|
| 743 |
-
)
|
| 744 |
-
fig.savefig("latest_logged_train_batch.png")
|
| 745 |
-
plt.close(fig)
|
| 746 |
-
|
| 747 |
-
def training_step(self, batch, batch_idx):
|
| 748 |
-
"""Run training step"""
|
| 749 |
-
y_hat = self(batch)
|
| 750 |
-
|
| 751 |
-
# Batch is adapted in the model forward method, but needs to be adapted here too
|
| 752 |
-
batch = self._adapt_batch(batch)
|
| 753 |
-
|
| 754 |
-
y = batch[self._target_key][:, -self.forecast_len :]
|
| 755 |
-
|
| 756 |
-
losses = self._calculate_common_losses(y, y_hat)
|
| 757 |
-
losses = {f"{k}/train": v for k, v in losses.items()}
|
| 758 |
-
|
| 759 |
-
self._training_accumulate_log(batch, batch_idx, losses, y_hat)
|
| 760 |
-
|
| 761 |
-
if self.use_quantile_regression:
|
| 762 |
-
opt_target = losses["quantile_loss/train"]
|
| 763 |
-
else:
|
| 764 |
-
opt_target = losses["MAE/train"]
|
| 765 |
-
return opt_target
|
| 766 |
-
|
| 767 |
-
def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, plot_suffix):
|
| 768 |
-
"""Log forecast plot to wandb"""
|
| 769 |
-
fig = plot_batch_forecasts(
|
| 770 |
-
batch,
|
| 771 |
-
y_hat,
|
| 772 |
-
quantiles=self.output_quantiles,
|
| 773 |
-
key_to_plot=self._target_key,
|
| 774 |
-
)
|
| 775 |
-
|
| 776 |
-
plot_name = f"val_forecast_samples/batch_idx_{accum_batch_num}_{plot_suffix}"
|
| 777 |
-
|
| 778 |
-
try:
|
| 779 |
-
self.logger.experiment.log({plot_name: wandb.Image(fig)})
|
| 780 |
-
except Exception as e:
|
| 781 |
-
print(f"Failed to log {plot_name} to wandb")
|
| 782 |
-
print(e)
|
| 783 |
-
plt.close(fig)
|
| 784 |
-
|
| 785 |
-
def _log_validation_results(self, batch, y_hat, accum_batch_num):
|
| 786 |
-
"""Append validation results to self.validation_epoch_results"""
|
| 787 |
-
|
| 788 |
-
# get truth values, shape (b, forecast_len)
|
| 789 |
-
y = batch[self._target_key][:, -self.forecast_len :]
|
| 790 |
-
y = y.detach().cpu().numpy()
|
| 791 |
-
batch_size = y.shape[0]
|
| 792 |
-
|
| 793 |
-
# get prediction values, shape (b, forecast_len, quantiles?)
|
| 794 |
-
y_hat = y_hat.detach().cpu().numpy()
|
| 795 |
-
|
| 796 |
-
# get time_utc, shape (b, forecast_len)
|
| 797 |
-
time_utc_key = f"{self._target_key}_time_utc"
|
| 798 |
-
time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy()
|
| 799 |
-
|
| 800 |
-
# get target id and change from (b,1) to (b,)
|
| 801 |
-
id_key = f"{self._target_key}_id"
|
| 802 |
-
target_id = batch[id_key].detach().cpu().numpy()
|
| 803 |
-
target_id = target_id.squeeze()
|
| 804 |
-
|
| 805 |
-
for i in range(batch_size):
|
| 806 |
-
y_i = y[i]
|
| 807 |
-
y_hat_i = y_hat[i]
|
| 808 |
-
time_utc_i = time_utc[i]
|
| 809 |
-
target_id_i = target_id[i]
|
| 810 |
-
|
| 811 |
-
results_dict = {
|
| 812 |
-
"y": y_i,
|
| 813 |
-
"time_utc": time_utc_i,
|
| 814 |
-
}
|
| 815 |
-
if self.use_quantile_regression:
|
| 816 |
-
results_dict.update(
|
| 817 |
-
{f"y_quantile_{q}": y_hat_i[:, i] for i, q in enumerate(self.output_quantiles)}
|
| 818 |
-
)
|
| 819 |
-
else:
|
| 820 |
-
results_dict["y_hat"] = y_hat_i
|
| 821 |
-
|
| 822 |
-
results_df = pd.DataFrame(results_dict)
|
| 823 |
-
results_df["id"] = target_id_i
|
| 824 |
-
results_df["batch_idx"] = accum_batch_num
|
| 825 |
-
results_df["example_idx"] = i
|
| 826 |
-
|
| 827 |
-
self.validation_epoch_results.append(results_df)
|
| 828 |
-
|
| 829 |
-
def validation_step(self, batch: dict, batch_idx):
|
| 830 |
-
"""Run validation step"""
|
| 831 |
-
|
| 832 |
-
accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
|
| 833 |
-
|
| 834 |
-
y_hat = self(batch)
|
| 835 |
-
# Batch is adapted in the model forward method, but needs to be adapted here too
|
| 836 |
-
batch = self._adapt_batch(batch)
|
| 837 |
-
|
| 838 |
-
y = batch[self._target_key][:, -self.forecast_len :]
|
| 839 |
-
|
| 840 |
-
if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
|
| 841 |
-
self._log_validation_results(batch, y_hat, accum_batch_num)
|
| 842 |
-
|
| 843 |
-
# Expand persistence to be the same shape as y
|
| 844 |
-
losses = self._calculate_common_losses(y, y_hat)
|
| 845 |
-
losses.update(self._calculate_val_losses(y, y_hat))
|
| 846 |
-
|
| 847 |
-
# Store these to make horizon accuracy plot
|
| 848 |
-
self._horizon_maes.append(
|
| 849 |
-
{i: losses[f"MAE_horizon/step_{i:03}"].cpu().numpy() for i in range(self.forecast_len)}
|
| 850 |
-
)
|
| 851 |
-
|
| 852 |
-
logged_losses = {f"{k}/val": v for k, v in losses.items()}
|
| 853 |
-
|
| 854 |
-
self.log_dict(
|
| 855 |
-
logged_losses,
|
| 856 |
-
on_step=False,
|
| 857 |
-
on_epoch=True,
|
| 858 |
-
)
|
| 859 |
-
|
| 860 |
-
# Make plots only if using wandb logger
|
| 861 |
-
if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
|
| 862 |
-
# Store these temporarily under self
|
| 863 |
-
if not hasattr(self, "_val_y_hats"):
|
| 864 |
-
self._val_y_hats = PredAccumulator()
|
| 865 |
-
self._val_batches = BatchAccumulator(key_to_keep=self._target_key)
|
| 866 |
-
|
| 867 |
-
self._val_y_hats.append(y_hat)
|
| 868 |
-
self._val_batches.append(batch)
|
| 869 |
-
|
| 870 |
-
# if batch has accumulated
|
| 871 |
-
if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
|
| 872 |
-
y_hat = self._val_y_hats.flush()
|
| 873 |
-
batch = self._val_batches.flush()
|
| 874 |
-
|
| 875 |
-
self._log_forecast_plot(
|
| 876 |
-
batch,
|
| 877 |
-
y_hat,
|
| 878 |
-
accum_batch_num,
|
| 879 |
-
timesteps_to_plot=None,
|
| 880 |
-
plot_suffix="all",
|
| 881 |
-
)
|
| 882 |
-
|
| 883 |
-
if self.time_step_intervals_to_plot is not None:
|
| 884 |
-
for interval in self.time_step_intervals_to_plot:
|
| 885 |
-
self._log_forecast_plot(
|
| 886 |
-
batch,
|
| 887 |
-
y_hat,
|
| 888 |
-
accum_batch_num,
|
| 889 |
-
timesteps_to_plot=interval,
|
| 890 |
-
plot_suffix=f"timestep_{interval}",
|
| 891 |
-
)
|
| 892 |
-
|
| 893 |
-
del self._val_y_hats
|
| 894 |
-
del self._val_batches
|
| 895 |
-
|
| 896 |
-
return logged_losses
|
| 897 |
-
|
| 898 |
-
def on_validation_epoch_end(self):
|
| 899 |
-
"""Run on epoch end"""
|
| 900 |
-
|
| 901 |
-
try:
|
| 902 |
-
# join together validation results, and save to wandb
|
| 903 |
-
validation_results_df = pd.concat(self.validation_epoch_results)
|
| 904 |
-
validation_results_df["error"] = (
|
| 905 |
-
validation_results_df["y"] - validation_results_df["y_quantile_0.5"]
|
| 906 |
-
)
|
| 907 |
-
|
| 908 |
-
if isinstance(self.logger, pl.loggers.WandbLogger):
|
| 909 |
-
# log error distribution metrics
|
| 910 |
-
wandb.log(
|
| 911 |
-
{
|
| 912 |
-
"2nd_percentile_median_forecast_error": validation_results_df[
|
| 913 |
-
"error"
|
| 914 |
-
].quantile(0.02),
|
| 915 |
-
"5th_percentile_median_forecast_error": validation_results_df[
|
| 916 |
-
"error"
|
| 917 |
-
].quantile(0.05),
|
| 918 |
-
"95th_percentile_median_forecast_error": validation_results_df[
|
| 919 |
-
"error"
|
| 920 |
-
].quantile(0.95),
|
| 921 |
-
"98th_percentile_median_forecast_error": validation_results_df[
|
| 922 |
-
"error"
|
| 923 |
-
].quantile(0.98),
|
| 924 |
-
"95th_percentile_median_forecast_absolute_error": abs(
|
| 925 |
-
validation_results_df["error"]
|
| 926 |
-
).quantile(0.95),
|
| 927 |
-
"98th_percentile_median_forecast_absolute_error": abs(
|
| 928 |
-
validation_results_df["error"]
|
| 929 |
-
).quantile(0.98),
|
| 930 |
-
}
|
| 931 |
-
)
|
| 932 |
-
# saving validation result csvs
|
| 933 |
-
if self.save_validation_results_csv:
|
| 934 |
-
with tempfile.TemporaryDirectory() as tempdir:
|
| 935 |
-
filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv")
|
| 936 |
-
validation_results_df.to_csv(filename, index=False)
|
| 937 |
-
|
| 938 |
-
# make and log wand artifact
|
| 939 |
-
validation_artifact = wandb.Artifact(
|
| 940 |
-
f"validation_results_epoch_{self.current_epoch}", type="dataset"
|
| 941 |
-
)
|
| 942 |
-
validation_artifact.add_file(filename)
|
| 943 |
-
wandb.log_artifact(validation_artifact)
|
| 944 |
-
|
| 945 |
-
except Exception as e:
|
| 946 |
-
print("Failed to log validation results to wandb")
|
| 947 |
-
print(e)
|
| 948 |
-
|
| 949 |
-
self.validation_epoch_results = []
|
| 950 |
-
horizon_maes_dict = self._horizon_maes.flush()
|
| 951 |
-
|
| 952 |
-
# Create the horizon accuracy curve
|
| 953 |
-
if isinstance(self.logger, pl.loggers.WandbLogger):
|
| 954 |
-
per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)]
|
| 955 |
-
try:
|
| 956 |
-
table = wandb.Table(data=per_step_losses, columns=["horizon_step", "MAE"])
|
| 957 |
-
wandb.log(
|
| 958 |
-
{
|
| 959 |
-
"horizon_loss_curve": wandb.plot.line(
|
| 960 |
-
table, "horizon_step", "MAE", title="Horizon loss curve"
|
| 961 |
-
)
|
| 962 |
-
},
|
| 963 |
-
)
|
| 964 |
-
except Exception as e:
|
| 965 |
-
print("Failed to log horizon_loss_curve to wandb")
|
| 966 |
-
print(e)
|
| 967 |
-
|
| 968 |
-
def configure_optimizers(self):
|
| 969 |
-
"""Configure the optimizers using learning rate found with LR finder if used"""
|
| 970 |
-
if self.lr is not None:
|
| 971 |
-
# Use learning rate found by learning rate finder callback
|
| 972 |
-
self._optimizer.lr = self.lr
|
| 973 |
-
return self._optimizer(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/baseline/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Baselines"""
|
|
|
|
|
|
pvnet/models/baseline/last_value.py
DELETED
|
@@ -1,42 +0,0 @@
|
|
| 1 |
-
"""Persistence model"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import pvnet
|
| 5 |
-
from pvnet.models.base_model import BaseModel
|
| 6 |
-
from pvnet.optimizers import AbstractOptimizer
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class Model(BaseModel):
|
| 10 |
-
"""Simple baseline model that takes the last gsp yield value and copies it forward."""
|
| 11 |
-
|
| 12 |
-
name = "last_value"
|
| 13 |
-
|
| 14 |
-
def __init__(
|
| 15 |
-
self,
|
| 16 |
-
forecast_minutes: int = 12,
|
| 17 |
-
history_minutes: int = 6,
|
| 18 |
-
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
|
| 19 |
-
):
|
| 20 |
-
"""Simple baseline model that takes the last gsp yield value and copies it forward.
|
| 21 |
-
|
| 22 |
-
Args:
|
| 23 |
-
history_minutes (int): Length of the GSP history period in minutes
|
| 24 |
-
forecast_minutes (int): Length of the GSP forecast period in minutes
|
| 25 |
-
optimizer (AbstractOptimizer): Optimizer
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
super().__init__(history_minutes, forecast_minutes, optimizer)
|
| 29 |
-
self.save_hyperparameters()
|
| 30 |
-
|
| 31 |
-
def forward(self, x: dict):
|
| 32 |
-
"""Run model forward on dict batch of data"""
|
| 33 |
-
# Shape: batch_size, seq_length, n_sites
|
| 34 |
-
gsp_yield = x["gsp"]
|
| 35 |
-
|
| 36 |
-
# take the last value non forecaster value and the first in the pv yeild
|
| 37 |
-
# (this is the pv site we are preditcting for)
|
| 38 |
-
y_hat = gsp_yield[:, -self.forecast_len - 1]
|
| 39 |
-
|
| 40 |
-
# expand the last valid forward n predict steps
|
| 41 |
-
out = y_hat.unsqueeze(1).repeat(1, self.forecast_len)
|
| 42 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/baseline/readme.md
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# Baseline Models
|
| 2 |
-
|
| 3 |
-
- `last_value` - Forecast the sample last historical PV yeild for every forecast step
|
| 4 |
-
- `single_value` - Learns a single value estimate and predicts this value for every input and every
|
| 5 |
-
forecast step.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/baseline/single_value.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
"""Average value model"""
|
| 2 |
-
import torch
|
| 3 |
-
from torch import nn
|
| 4 |
-
|
| 5 |
-
import pvnet
|
| 6 |
-
from pvnet.models.base_model import BaseModel
|
| 7 |
-
from pvnet.optimizers import AbstractOptimizer
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Model(BaseModel):
|
| 11 |
-
"""Simple baseline model that predicts always the same value."""
|
| 12 |
-
|
| 13 |
-
name = "single_value"
|
| 14 |
-
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
forecast_minutes: int = 120,
|
| 18 |
-
history_minutes: int = 60,
|
| 19 |
-
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
|
| 20 |
-
):
|
| 21 |
-
"""Simple baseline model that predicts always the same value.
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
history_minutes (int): Length of the GSP history period in minutes
|
| 25 |
-
forecast_minutes (int): Length of the GSP forecast period in minutes
|
| 26 |
-
optimizer (AbstractOptimizer): Optimizer
|
| 27 |
-
"""
|
| 28 |
-
super().__init__(history_minutes, forecast_minutes, optimizer)
|
| 29 |
-
self._value = nn.Parameter(torch.zeros(1), requires_grad=True)
|
| 30 |
-
self.save_hyperparameters()
|
| 31 |
-
|
| 32 |
-
def forward(self, x: dict):
|
| 33 |
-
"""Run model forward on dict batch of data"""
|
| 34 |
-
# Returns a single value at all steps
|
| 35 |
-
y_hat = torch.zeros_like(x["gsp"][:, : self.forecast_len]) + self._value
|
| 36 |
-
return y_hat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/ensemble.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
"""Model which uses mutliple prediction heads"""
|
| 2 |
-
from typing import Optional
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from pvnet.models.base_model import BaseModel
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Ensemble(BaseModel):
|
| 11 |
-
"""Ensemble of PVNet models"""
|
| 12 |
-
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
model_list: list[BaseModel],
|
| 16 |
-
weights: Optional[list[float]] = None,
|
| 17 |
-
):
|
| 18 |
-
"""Ensemble of PVNet models
|
| 19 |
-
|
| 20 |
-
Args:
|
| 21 |
-
model_list: A list of PVNet models to ensemble
|
| 22 |
-
weights: A list of weighting to apply to each model. If None, the models are weighted
|
| 23 |
-
equally.
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
# Surface check all the models are compatible
|
| 27 |
-
output_quantiles = []
|
| 28 |
-
history_minutes = []
|
| 29 |
-
forecast_minutes = []
|
| 30 |
-
target_key = []
|
| 31 |
-
interval_minutes = []
|
| 32 |
-
|
| 33 |
-
# Get some model properties from each model
|
| 34 |
-
for model in model_list:
|
| 35 |
-
output_quantiles.append(model.output_quantiles)
|
| 36 |
-
history_minutes.append(model.history_minutes)
|
| 37 |
-
forecast_minutes.append(model.forecast_minutes)
|
| 38 |
-
target_key.append(model._target_key)
|
| 39 |
-
interval_minutes.append(model.interval_minutes)
|
| 40 |
-
|
| 41 |
-
# Check these properties are all the same
|
| 42 |
-
for param_list in [
|
| 43 |
-
output_quantiles,
|
| 44 |
-
history_minutes,
|
| 45 |
-
forecast_minutes,
|
| 46 |
-
target_key,
|
| 47 |
-
interval_minutes,
|
| 48 |
-
]:
|
| 49 |
-
assert all([p == param_list[0] for p in param_list]), param_list
|
| 50 |
-
|
| 51 |
-
super().__init__(
|
| 52 |
-
history_minutes=history_minutes[0],
|
| 53 |
-
forecast_minutes=forecast_minutes[0],
|
| 54 |
-
optimizer=None,
|
| 55 |
-
output_quantiles=output_quantiles[0],
|
| 56 |
-
target_key=target_key[0],
|
| 57 |
-
interval_minutes=interval_minutes[0],
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
self.model_list = nn.ModuleList(model_list)
|
| 61 |
-
|
| 62 |
-
if weights is None:
|
| 63 |
-
weights = torch.ones(len(model_list)) / len(model_list)
|
| 64 |
-
else:
|
| 65 |
-
assert len(weights) == len(model_list)
|
| 66 |
-
weights = torch.Tensor(weights) / sum(weights)
|
| 67 |
-
self.weights = nn.Parameter(weights, requires_grad=False)
|
| 68 |
-
|
| 69 |
-
def forward(self, batch):
|
| 70 |
-
"""Run the model forward"""
|
| 71 |
-
y_hat = 0
|
| 72 |
-
for weight, model in zip(self.weights, self.model_list):
|
| 73 |
-
y_hat = model(batch) * weight + y_hat
|
| 74 |
-
return y_hat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/model_cards/pv_india_model_card_template.md
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
{{ card_data }}
|
| 3 |
-
---
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# PVNet India
|
| 11 |
-
|
| 12 |
-
## Model Description
|
| 13 |
-
|
| 14 |
-
<!-- Provide a longer summary of what this model is/does. -->
|
| 15 |
-
This model class uses numerical weather predictions from providers such as ECMWF to forecast the PV power in North West India over the next 48 hours. More information can be found in the model repo [1] and experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india).
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
- **Developed by:** openclimatefix
|
| 19 |
-
- **Model type:** Fusion model
|
| 20 |
-
- **Language(s) (NLP):** en
|
| 21 |
-
- **License:** mit
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# Training Details
|
| 25 |
-
|
| 26 |
-
## Data
|
| 27 |
-
|
| 28 |
-
<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 29 |
-
|
| 30 |
-
The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
### Preprocessing
|
| 34 |
-
|
| 35 |
-
Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/site` Dataset [2].
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
## Results
|
| 39 |
-
|
| 40 |
-
The training logs for the current model can be found here:
|
| 41 |
-
{{ wandb_links }}
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
### Hardware
|
| 45 |
-
|
| 46 |
-
Trained on a single NVIDIA Tesla T4
|
| 47 |
-
|
| 48 |
-
### Software
|
| 49 |
-
|
| 50 |
-
This model was trained using the following Open Climate Fix packages:
|
| 51 |
-
|
| 52 |
-
- [1] https://github.com/openclimatefix/PVNet
|
| 53 |
-
- [2] https://github.com/openclimatefix/ocf-data-sampler
|
| 54 |
-
|
| 55 |
-
The versions of these packages can be found below:
|
| 56 |
-
{{ package_versions }}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/model_cards/pv_uk_regional_model_card_template.md
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
{{ card_data }}
|
| 3 |
-
---
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# PVNet2
|
| 11 |
-
|
| 12 |
-
## Model Description
|
| 13 |
-
|
| 14 |
-
<!-- Provide a longer summary of what this model is/does. -->
|
| 15 |
-
This model class uses satellite data, numerical weather predictions, and recent Grid Service Point( GSP) PV power output to forecast the near-term (~8 hours) PV power output at all GSPs. More information can be found in the model repo [1] and experimental notes in [this google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing).
|
| 16 |
-
|
| 17 |
-
- **Developed by:** openclimatefix
|
| 18 |
-
- **Model type:** Fusion model
|
| 19 |
-
- **Language(s) (NLP):** en
|
| 20 |
-
- **License:** mit
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
# Training Details
|
| 24 |
-
|
| 25 |
-
## Data
|
| 26 |
-
|
| 27 |
-
<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 28 |
-
|
| 29 |
-
The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes in the [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details.
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
### Preprocessing
|
| 33 |
-
|
| 34 |
-
Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/pvnet_uk` Dataset [2].
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
## Results
|
| 38 |
-
|
| 39 |
-
The training logs for the current model can be found here:
|
| 40 |
-
{{ wandb_links }}
|
| 41 |
-
|
| 42 |
-
The training logs for all model runs of PVNet2 can be found [here](https://wandb.ai/openclimatefix/pvnet2.1).
|
| 43 |
-
|
| 44 |
-
Some experimental notes can be found at in [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
### Hardware
|
| 48 |
-
|
| 49 |
-
Trained on a single NVIDIA Tesla T4
|
| 50 |
-
|
| 51 |
-
### Software
|
| 52 |
-
|
| 53 |
-
This model was trained using the following Open Climate Fix packages:
|
| 54 |
-
|
| 55 |
-
- [1] https://github.com/openclimatefix/PVNet
|
| 56 |
-
- [2] https://github.com/openclimatefix/ocf-data-sampler
|
| 57 |
-
|
| 58 |
-
The versions of these packages can be found below:
|
| 59 |
-
{{ package_versions }}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/model_cards/wind_india_model_card_template.md
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
{{ card_data }}
|
| 3 |
-
---
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# WindNet
|
| 11 |
-
|
| 12 |
-
## Model Description
|
| 13 |
-
|
| 14 |
-
<!-- Provide a longer summary of what this model is/does. -->
|
| 15 |
-
This model class uses numerical weather predictions from providers such as ECMWF to forecast the wind power in North West India over the next 48 hours at 15 minute granularity. More information can be found in the model repo [1] and experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india).
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
- **Developed by:** openclimatefix
|
| 19 |
-
- **Model type:** Fusion model
|
| 20 |
-
- **Language(s) (NLP):** en
|
| 21 |
-
- **License:** mit
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# Training Details
|
| 25 |
-
|
| 26 |
-
## Data
|
| 27 |
-
|
| 28 |
-
<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 29 |
-
|
| 30 |
-
The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
### Preprocessing
|
| 34 |
-
|
| 35 |
-
Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/site` Dataset [2].
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
## Results
|
| 39 |
-
|
| 40 |
-
The training logs for the current model can be found here:
|
| 41 |
-
{{ wandb_links }}
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
### Hardware
|
| 45 |
-
|
| 46 |
-
Trained on a single NVIDIA Tesla T4
|
| 47 |
-
|
| 48 |
-
### Software
|
| 49 |
-
|
| 50 |
-
This model was trained using the following Open Climate Fix packages:
|
| 51 |
-
|
| 52 |
-
- [1] https://github.com/openclimatefix/PVNet
|
| 53 |
-
- [2] https://github.com/openclimatefix/ocf-data-sampler
|
| 54 |
-
|
| 55 |
-
The versions of these packages can be found below:
|
| 56 |
-
{{ package_versions }}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Multimodal Models"""
|
|
|
|
|
|
pvnet/models/multimodal/basic_blocks.py
DELETED
|
@@ -1,104 +0,0 @@
|
|
| 1 |
-
"""Basic layers for composite models"""
|
| 2 |
-
|
| 3 |
-
import warnings
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import _VF, nn
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class ImageEmbedding(nn.Module):
|
| 10 |
-
"""A embedding layer which concatenates an ID embedding as a new channel onto 3D inputs."""
|
| 11 |
-
|
| 12 |
-
def __init__(self, num_embeddings, sequence_length, image_size_pixels, **kwargs):
|
| 13 |
-
"""A embedding layer which concatenates an ID embedding as a new channel onto 3D inputs.
|
| 14 |
-
|
| 15 |
-
The embedding is a single 2D image and is appended at each step in the 1st dimension
|
| 16 |
-
(assumed to be time).
|
| 17 |
-
|
| 18 |
-
Args:
|
| 19 |
-
num_embeddings: Size of the dictionary of embeddings
|
| 20 |
-
sequence_length: The time sequence length of the data.
|
| 21 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 22 |
-
**kwargs: See `torch.nn.Embedding` for more possible arguments.
|
| 23 |
-
"""
|
| 24 |
-
super().__init__()
|
| 25 |
-
self.image_size_pixels = image_size_pixels
|
| 26 |
-
self.sequence_length = sequence_length
|
| 27 |
-
self._embed = nn.Embedding(
|
| 28 |
-
num_embeddings=num_embeddings,
|
| 29 |
-
embedding_dim=image_size_pixels * image_size_pixels,
|
| 30 |
-
**kwargs,
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
def forward(self, x, id):
|
| 34 |
-
"""Append ID embedding to image"""
|
| 35 |
-
emb = self._embed(id)
|
| 36 |
-
emb = emb.reshape((-1, 1, 1, self.image_size_pixels, self.image_size_pixels))
|
| 37 |
-
emb = emb.repeat(1, 1, self.sequence_length, 1, 1)
|
| 38 |
-
x = torch.cat((x, emb), dim=1)
|
| 39 |
-
return x
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class CompleteDropoutNd(nn.Module):
|
| 43 |
-
"""A layer used to completely drop out all elements of a N-dimensional sample.
|
| 44 |
-
|
| 45 |
-
Each sample will be zeroed out independently on every forward call with probability `p` using
|
| 46 |
-
samples from a Bernoulli distribution.
|
| 47 |
-
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
__constants__ = ["p", "inplace", "n_dim"]
|
| 51 |
-
p: float
|
| 52 |
-
inplace: bool
|
| 53 |
-
n_dim: int
|
| 54 |
-
|
| 55 |
-
def __init__(self, n_dim, p=0.5, inplace=False):
|
| 56 |
-
"""A layer used to completely drop out all elements of a N-dimensional sample.
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
n_dim: Number of dimensions of each sample not including channels. E.g. a sample with
|
| 60 |
-
shape (channel, time, height, width) would use `n_dim=3`.
|
| 61 |
-
p: probability of a channel to be zeroed. Default: 0.5
|
| 62 |
-
training: apply dropout if is `True`. Default: `True`
|
| 63 |
-
inplace: If set to `True`, will do this operation in-place. Default: `False`
|
| 64 |
-
"""
|
| 65 |
-
super().__init__()
|
| 66 |
-
if p < 0 or p > 1:
|
| 67 |
-
raise ValueError(
|
| 68 |
-
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
| 69 |
-
)
|
| 70 |
-
self.p = p
|
| 71 |
-
self.inplace = inplace
|
| 72 |
-
self.n_dim = n_dim
|
| 73 |
-
|
| 74 |
-
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 75 |
-
"""Run dropout"""
|
| 76 |
-
p = self.p
|
| 77 |
-
inp_dim = input.dim()
|
| 78 |
-
|
| 79 |
-
if inp_dim not in (self.n_dim + 1, self.n_dim + 2):
|
| 80 |
-
warn_msg = (
|
| 81 |
-
f"CompleteDropoutNd: Received a {inp_dim}-D input. Expected either a single sample"
|
| 82 |
-
f" with {self.n_dim+1} dimensions, or a batch of samples with {self.n_dim+2}"
|
| 83 |
-
" dimensions."
|
| 84 |
-
)
|
| 85 |
-
warnings.warn(warn_msg)
|
| 86 |
-
|
| 87 |
-
is_batched = inp_dim == self.n_dim + 2
|
| 88 |
-
if not is_batched:
|
| 89 |
-
input = input.unsqueeze_(0) if self.inplace else input.unsqueeze(0)
|
| 90 |
-
|
| 91 |
-
input = input.unsqueeze_(1) if self.inplace else input.unsqueeze(1)
|
| 92 |
-
|
| 93 |
-
result = (
|
| 94 |
-
_VF.feature_dropout_(input, p, self.training)
|
| 95 |
-
if self.inplace
|
| 96 |
-
else _VF.feature_dropout(input, p, self.training)
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
result = result.squeeze_(1) if self.inplace else result.squeeze(1)
|
| 100 |
-
|
| 101 |
-
if not is_batched:
|
| 102 |
-
result = result.squeeze_(0) if self.inplace else result.squeeze(0)
|
| 103 |
-
|
| 104 |
-
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/encoders/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Submodels to encode satellite and NWP inputs"""
|
|
|
|
|
|
pvnet/models/multimodal/encoders/basic_blocks.py
DELETED
|
@@ -1,217 +0,0 @@
|
|
| 1 |
-
"""Basic blocks for image sequence encoders"""
|
| 2 |
-
from abc import ABCMeta, abstractmethod
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
|
| 9 |
-
"""Abstract class for NWP/satellite encoder.
|
| 10 |
-
|
| 11 |
-
The encoder will take an input of shape (batch_size, sequence_length, channels, height, width)
|
| 12 |
-
and return an output of shape (batch_size, out_features).
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
sequence_length: int,
|
| 18 |
-
image_size_pixels: int,
|
| 19 |
-
in_channels: int,
|
| 20 |
-
out_features: int,
|
| 21 |
-
):
|
| 22 |
-
"""Abstract class for NWP/satellite encoder.
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
sequence_length: The time sequence length of the data.
|
| 26 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 27 |
-
in_channels: Number of input channels.
|
| 28 |
-
out_features: Number of output features.
|
| 29 |
-
"""
|
| 30 |
-
super().__init__()
|
| 31 |
-
self.out_features = out_features
|
| 32 |
-
self.image_size_pixels = image_size_pixels
|
| 33 |
-
self.sequence_length = sequence_length
|
| 34 |
-
|
| 35 |
-
@abstractmethod
|
| 36 |
-
def forward(self):
|
| 37 |
-
"""Run model forward"""
|
| 38 |
-
pass
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class ResidualConv3dBlock(nn.Module):
|
| 42 |
-
"""Fully-connected deep network based on ResNet architecture.
|
| 43 |
-
|
| 44 |
-
Internally, this network uses ELU activations throughout the residual blocks.
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
def __init__(
|
| 48 |
-
self,
|
| 49 |
-
in_channels,
|
| 50 |
-
n_layers: int = 2,
|
| 51 |
-
dropout_frac: float = 0.0,
|
| 52 |
-
):
|
| 53 |
-
"""Fully-connected deep network based on ResNet architecture.
|
| 54 |
-
|
| 55 |
-
Args:
|
| 56 |
-
in_channels: Number of input channels.
|
| 57 |
-
n_layers: Number of layers in residual pathway.
|
| 58 |
-
dropout_frac: Probability of an element to be zeroed.
|
| 59 |
-
"""
|
| 60 |
-
super().__init__()
|
| 61 |
-
|
| 62 |
-
layers = []
|
| 63 |
-
for i in range(n_layers):
|
| 64 |
-
layers += [
|
| 65 |
-
nn.ELU(),
|
| 66 |
-
nn.Conv3d(
|
| 67 |
-
in_channels=in_channels,
|
| 68 |
-
out_channels=in_channels,
|
| 69 |
-
kernel_size=(3, 3, 3),
|
| 70 |
-
padding=(1, 1, 1),
|
| 71 |
-
),
|
| 72 |
-
nn.Dropout3d(p=dropout_frac),
|
| 73 |
-
]
|
| 74 |
-
|
| 75 |
-
self.model = nn.Sequential(*layers)
|
| 76 |
-
|
| 77 |
-
def forward(self, x):
|
| 78 |
-
"""Run residual connection"""
|
| 79 |
-
return self.model(x) + x
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
class ResidualConv3dBlock2(nn.Module):
|
| 83 |
-
"""Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
|
| 84 |
-
|
| 85 |
-
This was the best performing residual block tested in the study. This implementation differs
|
| 86 |
-
from that block just by using LeakyReLU activation to avoid dead neurons, and by including
|
| 87 |
-
optional dropout in the residual branch. This is also a 3D fully connected layer residual block
|
| 88 |
-
rather than a 2D convolutional block.
|
| 89 |
-
|
| 90 |
-
Sources:
|
| 91 |
-
[1] https://arxiv.org/pdf/1603.05027.pdf
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
def __init__(
|
| 95 |
-
self,
|
| 96 |
-
in_channels: int,
|
| 97 |
-
n_layers: int = 2,
|
| 98 |
-
dropout_frac: float = 0.0,
|
| 99 |
-
batch_norm: bool = True,
|
| 100 |
-
):
|
| 101 |
-
"""Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
|
| 102 |
-
|
| 103 |
-
Sources:
|
| 104 |
-
[1] https://arxiv.org/pdf/1603.05027.pdf
|
| 105 |
-
|
| 106 |
-
Args:
|
| 107 |
-
in_channels: Number of input channels.
|
| 108 |
-
n_layers: Number of layers in residual pathway.
|
| 109 |
-
dropout_frac: Probability of an element to be zeroed.
|
| 110 |
-
batch_norm: Whether to use batchnorm
|
| 111 |
-
"""
|
| 112 |
-
super().__init__()
|
| 113 |
-
|
| 114 |
-
layers = []
|
| 115 |
-
for i in range(n_layers):
|
| 116 |
-
if batch_norm:
|
| 117 |
-
layers.append(nn.BatchNorm3d(in_channels))
|
| 118 |
-
layers.extend(
|
| 119 |
-
[
|
| 120 |
-
nn.Dropout3d(p=dropout_frac),
|
| 121 |
-
nn.LeakyReLU(),
|
| 122 |
-
nn.Conv3d(
|
| 123 |
-
in_channels=in_channels,
|
| 124 |
-
out_channels=in_channels,
|
| 125 |
-
kernel_size=(3, 3, 3),
|
| 126 |
-
padding=(1, 1, 1),
|
| 127 |
-
),
|
| 128 |
-
]
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
self.model = nn.Sequential(*layers)
|
| 132 |
-
|
| 133 |
-
def forward(self, x):
|
| 134 |
-
"""Run model forward"""
|
| 135 |
-
return self.model(x) + x
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class ImageSequenceEncoder(nn.Module):
|
| 139 |
-
"""Simple network which independently encodes each image in a sequence into 1D features"""
|
| 140 |
-
|
| 141 |
-
def __init__(
|
| 142 |
-
self,
|
| 143 |
-
image_size_pixels: int,
|
| 144 |
-
in_channels: int,
|
| 145 |
-
number_of_conv2d_layers: int = 4,
|
| 146 |
-
conv2d_channels: int = 32,
|
| 147 |
-
fc_features: int = 128,
|
| 148 |
-
):
|
| 149 |
-
"""Simple network which independently encodes each image in a sequence into 1D features.
|
| 150 |
-
|
| 151 |
-
For input image with shape [N, C, L, H, W] the output is of shape [N, L, fc_features] where
|
| 152 |
-
N is number of samples in batch, C is the number of input channels, L is the length of the
|
| 153 |
-
sequence, and H and W are the height and width.
|
| 154 |
-
|
| 155 |
-
Args:
|
| 156 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 157 |
-
in_channels: Number of input channels.
|
| 158 |
-
number_of_conv2d_layers: Number of convolution 2D layers that are used.
|
| 159 |
-
conv2d_channels: Number of channels used in each conv2d layer.
|
| 160 |
-
fc_features: Number of output nodes for each image in each sequence.
|
| 161 |
-
"""
|
| 162 |
-
super().__init__()
|
| 163 |
-
|
| 164 |
-
# Check that the output shape of the convolutional layers will be at least 1x1
|
| 165 |
-
cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv2d_layers
|
| 166 |
-
if not (cnn_spatial_output_size >= 1):
|
| 167 |
-
raise ValueError(
|
| 168 |
-
f"cannot use this many conv2d layers ({number_of_conv2d_layers}) with this input "
|
| 169 |
-
f"spatial size ({image_size_pixels})"
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
conv_layers = []
|
| 173 |
-
|
| 174 |
-
conv_layers += [
|
| 175 |
-
nn.Conv2d(
|
| 176 |
-
in_channels=in_channels,
|
| 177 |
-
out_channels=conv2d_channels,
|
| 178 |
-
kernel_size=3,
|
| 179 |
-
padding=0,
|
| 180 |
-
),
|
| 181 |
-
nn.ELU(),
|
| 182 |
-
]
|
| 183 |
-
for i in range(0, number_of_conv2d_layers - 1):
|
| 184 |
-
conv_layers += [
|
| 185 |
-
nn.Conv2d(
|
| 186 |
-
in_channels=conv2d_channels,
|
| 187 |
-
out_channels=conv2d_channels,
|
| 188 |
-
kernel_size=3,
|
| 189 |
-
padding=0,
|
| 190 |
-
),
|
| 191 |
-
nn.ELU(),
|
| 192 |
-
]
|
| 193 |
-
|
| 194 |
-
self.conv_layers = nn.Sequential(*conv_layers)
|
| 195 |
-
|
| 196 |
-
self.final_block = nn.Sequential(
|
| 197 |
-
nn.Linear(
|
| 198 |
-
in_features=(cnn_spatial_output_size**2) * conv2d_channels,
|
| 199 |
-
out_features=fc_features,
|
| 200 |
-
),
|
| 201 |
-
nn.ELU(),
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
def forward(self, x):
|
| 205 |
-
"""Run model forward"""
|
| 206 |
-
batch_size, channel, seq_len, height, width = x.shape
|
| 207 |
-
|
| 208 |
-
x = torch.swapaxes(x, 1, 2)
|
| 209 |
-
x = x.reshape(batch_size * seq_len, channel, height, width)
|
| 210 |
-
|
| 211 |
-
out = self.conv_layers(x)
|
| 212 |
-
out = out.reshape(batch_size * seq_len, -1)
|
| 213 |
-
|
| 214 |
-
out = self.final_block(out)
|
| 215 |
-
out = out.reshape(batch_size, seq_len, -1)
|
| 216 |
-
|
| 217 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/encoders/encoders2d.py
DELETED
|
@@ -1,413 +0,0 @@
|
|
| 1 |
-
"""Encoder modules for the satellite/NWP data.
|
| 2 |
-
|
| 3 |
-
These networks naively stack the sequences into extra channels before putting through their
|
| 4 |
-
architectures.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from functools import partial
|
| 8 |
-
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
from torch import Tensor, nn
|
| 12 |
-
from torchvision.models.convnext import CNBlock, CNBlockConfig, LayerNorm2d
|
| 13 |
-
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
|
| 14 |
-
from torchvision.ops.misc import Conv2dNormActivation
|
| 15 |
-
from torchvision.utils import _log_api_usage_once
|
| 16 |
-
|
| 17 |
-
from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class NaiveEfficientNet(AbstractNWPSatelliteEncoder):
|
| 21 |
-
"""An implementation of EfficientNet from `efficientnet_pytorch`.
|
| 22 |
-
|
| 23 |
-
This model is quite naive, and just stacks the sequence into channels.
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
def __init__(
|
| 27 |
-
self,
|
| 28 |
-
sequence_length: int,
|
| 29 |
-
image_size_pixels: int,
|
| 30 |
-
in_channels: int,
|
| 31 |
-
out_features: int,
|
| 32 |
-
model_name: str = "efficientnet-b0",
|
| 33 |
-
):
|
| 34 |
-
"""An implementation of EfficientNet from `efficientnet_pytorch`.
|
| 35 |
-
|
| 36 |
-
This model is quite naive, and just stacks the sequence into channels.
|
| 37 |
-
|
| 38 |
-
Args:
|
| 39 |
-
sequence_length: The time sequence length of the data.
|
| 40 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 41 |
-
in_channels: Number of input channels.
|
| 42 |
-
out_features: Number of output features.
|
| 43 |
-
model_name: Name of EfficientNet model to construct.
|
| 44 |
-
|
| 45 |
-
Notes:
|
| 46 |
-
The `efficientnet_pytorch` package must be installed to use `EncoderNaiveEfficientNet`.
|
| 47 |
-
See https://github.com/lukemelas/EfficientNet-PyTorch for install instructions.
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
from efficientnet_pytorch import EfficientNet
|
| 51 |
-
|
| 52 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 53 |
-
|
| 54 |
-
self.model = EfficientNet.from_name(
|
| 55 |
-
model_name,
|
| 56 |
-
in_channels=in_channels * sequence_length,
|
| 57 |
-
image_size=image_size_pixels,
|
| 58 |
-
num_classes=out_features,
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
def forward(self, x):
|
| 62 |
-
"""Run model forward"""
|
| 63 |
-
bs, s, c, h, w = x.shape
|
| 64 |
-
x = x.reshape((bs, s * c, h, w))
|
| 65 |
-
return self.model(x)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
class NaiveResNet(nn.Module):
|
| 69 |
-
"""A ResNet model modified from one in torchvision [1].
|
| 70 |
-
|
| 71 |
-
Modified allow different number of input channels. This model is quite naive, and just stacks
|
| 72 |
-
the sequence into channels.
|
| 73 |
-
|
| 74 |
-
Example use:
|
| 75 |
-
```
|
| 76 |
-
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
|
| 77 |
-
resnet50 = ResNet(Bottleneck, [3, 4, 6, 3])
|
| 78 |
-
```
|
| 79 |
-
|
| 80 |
-
Sources:
|
| 81 |
-
[1] https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
| 82 |
-
[2] https://pytorch.org/hub/pytorch_vision_resnet
|
| 83 |
-
"""
|
| 84 |
-
|
| 85 |
-
def __init__(
|
| 86 |
-
self,
|
| 87 |
-
sequence_length: int,
|
| 88 |
-
image_size_pixels: int,
|
| 89 |
-
in_channels: int,
|
| 90 |
-
out_features: int,
|
| 91 |
-
layers: List[int] = [2, 2, 2, 2],
|
| 92 |
-
block: str = "bottleneck",
|
| 93 |
-
zero_init_residual: bool = False,
|
| 94 |
-
groups: int = 1,
|
| 95 |
-
width_per_group: int = 64,
|
| 96 |
-
replace_stride_with_dilation: Optional[List[bool]] = None,
|
| 97 |
-
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 98 |
-
):
|
| 99 |
-
"""A ResNet model modified from one in torchvision [1].
|
| 100 |
-
|
| 101 |
-
Args:
|
| 102 |
-
sequence_length: The time sequence length of the data.
|
| 103 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 104 |
-
in_channels: Number of input channels.
|
| 105 |
-
out_features: Number of output features.
|
| 106 |
-
layers: See [1] and [2].
|
| 107 |
-
block: See [1] and [2].
|
| 108 |
-
zero_init_residual: See [1] and [2].
|
| 109 |
-
groups: See [1] and [2].
|
| 110 |
-
width_per_group: See [1] and [2].
|
| 111 |
-
replace_stride_with_dilation: See [1] and [2].
|
| 112 |
-
norm_layer: See [1] and [2].
|
| 113 |
-
|
| 114 |
-
Sources:
|
| 115 |
-
[1] https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
| 116 |
-
[2] https://pytorch.org/hub/pytorch_vision_resnet
|
| 117 |
-
"""
|
| 118 |
-
super().__init__()
|
| 119 |
-
_log_api_usage_once(self)
|
| 120 |
-
if norm_layer is None:
|
| 121 |
-
norm_layer = nn.BatchNorm2d
|
| 122 |
-
self._norm_layer = norm_layer
|
| 123 |
-
|
| 124 |
-
# Account for stacking sequences into more channels
|
| 125 |
-
in_channels = in_channels * sequence_length
|
| 126 |
-
|
| 127 |
-
block = {
|
| 128 |
-
"basic": BasicBlock,
|
| 129 |
-
"bottleneck": Bottleneck,
|
| 130 |
-
}[block]
|
| 131 |
-
|
| 132 |
-
self.inplanes = 64
|
| 133 |
-
self.dilation = 1
|
| 134 |
-
if replace_stride_with_dilation is None:
|
| 135 |
-
# each element in the tuple indicates if we should replace
|
| 136 |
-
# the 2x2 stride with a dilated convolution instead
|
| 137 |
-
replace_stride_with_dilation = [False, False, False]
|
| 138 |
-
if len(replace_stride_with_dilation) != 3:
|
| 139 |
-
raise ValueError(
|
| 140 |
-
"replace_stride_with_dilation should be None "
|
| 141 |
-
f"or a 3-element tuple, got {replace_stride_with_dilation}"
|
| 142 |
-
)
|
| 143 |
-
self.groups = groups
|
| 144 |
-
self.base_width = width_per_group
|
| 145 |
-
self.conv1 = nn.Conv2d(
|
| 146 |
-
in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
|
| 147 |
-
)
|
| 148 |
-
self.bn1 = norm_layer(self.inplanes)
|
| 149 |
-
self.relu = nn.ReLU(inplace=True)
|
| 150 |
-
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 151 |
-
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 152 |
-
self.layer2 = self._make_layer(
|
| 153 |
-
block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
|
| 154 |
-
)
|
| 155 |
-
self.layer3 = self._make_layer(
|
| 156 |
-
block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
|
| 157 |
-
)
|
| 158 |
-
self.layer4 = self._make_layer(
|
| 159 |
-
block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
|
| 160 |
-
)
|
| 161 |
-
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 162 |
-
self.fc = nn.Linear(512 * block.expansion, out_features)
|
| 163 |
-
self.final_act = nn.LeakyReLU()
|
| 164 |
-
|
| 165 |
-
for m in self.modules():
|
| 166 |
-
if isinstance(m, nn.Conv2d):
|
| 167 |
-
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 168 |
-
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 169 |
-
nn.init.constant_(m.weight, 1)
|
| 170 |
-
nn.init.constant_(m.bias, 0)
|
| 171 |
-
|
| 172 |
-
# Zero-initialize the last BN in each residual branch,
|
| 173 |
-
# so that the residual branch starts with zeros, and each residual block behaves like an
|
| 174 |
-
# identity. This improves the model by 0.2~0.3% according to
|
| 175 |
-
# https://arxiv.org/abs/1706.02677
|
| 176 |
-
if zero_init_residual:
|
| 177 |
-
for m in self.modules():
|
| 178 |
-
if isinstance(m, Bottleneck) and m.bn3.weight is not None:
|
| 179 |
-
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
| 180 |
-
elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
|
| 181 |
-
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
| 182 |
-
|
| 183 |
-
def _make_layer(
|
| 184 |
-
self,
|
| 185 |
-
block: Type[Union[BasicBlock, Bottleneck]],
|
| 186 |
-
planes: int,
|
| 187 |
-
blocks: int,
|
| 188 |
-
stride: int = 1,
|
| 189 |
-
dilate: bool = False,
|
| 190 |
-
) -> nn.Sequential:
|
| 191 |
-
norm_layer = self._norm_layer
|
| 192 |
-
downsample = None
|
| 193 |
-
previous_dilation = self.dilation
|
| 194 |
-
if dilate:
|
| 195 |
-
self.dilation *= stride
|
| 196 |
-
stride = 1
|
| 197 |
-
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 198 |
-
downsample = nn.Sequential(
|
| 199 |
-
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 200 |
-
norm_layer(planes * block.expansion),
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
layers = []
|
| 204 |
-
layers.append(
|
| 205 |
-
block(
|
| 206 |
-
self.inplanes,
|
| 207 |
-
planes,
|
| 208 |
-
stride,
|
| 209 |
-
downsample,
|
| 210 |
-
self.groups,
|
| 211 |
-
self.base_width,
|
| 212 |
-
previous_dilation,
|
| 213 |
-
norm_layer,
|
| 214 |
-
)
|
| 215 |
-
)
|
| 216 |
-
self.inplanes = planes * block.expansion
|
| 217 |
-
for _ in range(1, blocks):
|
| 218 |
-
layers.append(
|
| 219 |
-
block(
|
| 220 |
-
self.inplanes,
|
| 221 |
-
planes,
|
| 222 |
-
groups=self.groups,
|
| 223 |
-
base_width=self.base_width,
|
| 224 |
-
dilation=self.dilation,
|
| 225 |
-
norm_layer=norm_layer,
|
| 226 |
-
)
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
return nn.Sequential(*layers)
|
| 230 |
-
|
| 231 |
-
def _forward_impl(self, x: Tensor) -> Tensor:
|
| 232 |
-
# See note [TorchScript super()]
|
| 233 |
-
x = self.conv1(x)
|
| 234 |
-
x = self.bn1(x)
|
| 235 |
-
x = self.relu(x)
|
| 236 |
-
# x = self.maxpool(x)
|
| 237 |
-
|
| 238 |
-
x = self.layer1(x)
|
| 239 |
-
x = self.layer2(x)
|
| 240 |
-
x = self.layer3(x)
|
| 241 |
-
x = self.layer4(x)
|
| 242 |
-
|
| 243 |
-
x = self.avgpool(x)
|
| 244 |
-
x = torch.flatten(x, 1)
|
| 245 |
-
x = self.fc(x)
|
| 246 |
-
x = self.final_act(x)
|
| 247 |
-
|
| 248 |
-
return x
|
| 249 |
-
|
| 250 |
-
def forward(self, x: Tensor) -> Tensor:
|
| 251 |
-
"""Run model forward"""
|
| 252 |
-
bs, s, c, h, w = x.shape
|
| 253 |
-
x = x.reshape((bs, s * c, h, w))
|
| 254 |
-
return self._forward_impl(x)
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
class NaiveConvNeXt(nn.Module):
|
| 258 |
-
"""A NaiveConvNeXt model [1] modified from one in torchvision [2].
|
| 259 |
-
|
| 260 |
-
Mopdified to allow different number of input channels, and smaller spatial inputs. This model is
|
| 261 |
-
quite naive, and just stacks the sequence into channels.
|
| 262 |
-
|
| 263 |
-
Example usage:
|
| 264 |
-
```
|
| 265 |
-
block_setting = [
|
| 266 |
-
CNBlockConfig(96, 192, 3),
|
| 267 |
-
CNBlockConfig(192, 384, 3),
|
| 268 |
-
CNBlockConfig(384, 768, 9),
|
| 269 |
-
CNBlockConfig(768, None, 3),
|
| 270 |
-
]
|
| 271 |
-
|
| 272 |
-
sequence_len = 12
|
| 273 |
-
channels = 2
|
| 274 |
-
pixels=24
|
| 275 |
-
|
| 276 |
-
convnext_tiny = ConvNeXt(
|
| 277 |
-
sequence_length=12,
|
| 278 |
-
image_size_pixels=24,
|
| 279 |
-
in_channels=2,
|
| 280 |
-
out_features=128,
|
| 281 |
-
block_setting=block_setting,
|
| 282 |
-
stochastic_depth_prob=0.1,
|
| 283 |
-
)
|
| 284 |
-
```
|
| 285 |
-
|
| 286 |
-
Sources:
|
| 287 |
-
[1] https://arxiv.org/abs/2201.03545
|
| 288 |
-
[2] https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
|
| 289 |
-
[3] https://pytorch.org/vision/main/models/convnext.html
|
| 290 |
-
|
| 291 |
-
"""
|
| 292 |
-
|
| 293 |
-
def __init__(
|
| 294 |
-
self,
|
| 295 |
-
sequence_length: int,
|
| 296 |
-
image_size_pixels: int,
|
| 297 |
-
in_channels: int,
|
| 298 |
-
out_features: int,
|
| 299 |
-
block_setting: List[CNBlockConfig],
|
| 300 |
-
stochastic_depth_prob: float = 0.0,
|
| 301 |
-
layer_scale: float = 1e-6,
|
| 302 |
-
block: Optional[Callable[..., nn.Module]] = None,
|
| 303 |
-
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 304 |
-
**kwargs: Any,
|
| 305 |
-
) -> None:
|
| 306 |
-
"""A ConvNeXt model [1] modified from one in torchvision [2].
|
| 307 |
-
|
| 308 |
-
Args:
|
| 309 |
-
sequence_length: The time sequence length of the data.
|
| 310 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 311 |
-
in_channels: Number of input channels.
|
| 312 |
-
out_features: Number of output features.
|
| 313 |
-
block_setting: See [2] and [3].
|
| 314 |
-
stochastic_depth_prob: See [2] and [3].
|
| 315 |
-
layer_scale: See [2] and [3].
|
| 316 |
-
block: See [2] and [3].
|
| 317 |
-
norm_layer: See [2] and [3].
|
| 318 |
-
**kwargs: See [2] and [3].
|
| 319 |
-
|
| 320 |
-
Sources:
|
| 321 |
-
[1] https://arxiv.org/abs/2201.03545
|
| 322 |
-
[2] https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
|
| 323 |
-
[3] https://pytorch.org/vision/main/models/convnext.html
|
| 324 |
-
"""
|
| 325 |
-
super().__init__()
|
| 326 |
-
_log_api_usage_once(self)
|
| 327 |
-
|
| 328 |
-
if not block_setting:
|
| 329 |
-
raise ValueError("The block_setting should not be empty")
|
| 330 |
-
elif not (
|
| 331 |
-
isinstance(block_setting, Sequence)
|
| 332 |
-
and all([isinstance(s, CNBlockConfig) for s in block_setting])
|
| 333 |
-
):
|
| 334 |
-
raise TypeError("The block_setting should be List[CNBlockConfig]")
|
| 335 |
-
|
| 336 |
-
if block is None:
|
| 337 |
-
block = CNBlock
|
| 338 |
-
|
| 339 |
-
if norm_layer is None:
|
| 340 |
-
norm_layer = partial(LayerNorm2d, eps=1e-6)
|
| 341 |
-
|
| 342 |
-
layers: List[nn.Module] = []
|
| 343 |
-
|
| 344 |
-
# Account for stacking sequences into more channels
|
| 345 |
-
in_channels = in_channels * sequence_length
|
| 346 |
-
|
| 347 |
-
# Stem
|
| 348 |
-
firstconv_output_channels = block_setting[0].input_channels
|
| 349 |
-
layers.append(
|
| 350 |
-
Conv2dNormActivation(
|
| 351 |
-
in_channels,
|
| 352 |
-
firstconv_output_channels,
|
| 353 |
-
kernel_size=2,
|
| 354 |
-
stride=2,
|
| 355 |
-
padding=0,
|
| 356 |
-
norm_layer=norm_layer,
|
| 357 |
-
activation_layer=None,
|
| 358 |
-
bias=True,
|
| 359 |
-
)
|
| 360 |
-
)
|
| 361 |
-
|
| 362 |
-
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
|
| 363 |
-
stage_block_id = 0
|
| 364 |
-
for cnf in block_setting:
|
| 365 |
-
# Bottlenecks
|
| 366 |
-
stage: List[nn.Module] = []
|
| 367 |
-
for _ in range(cnf.num_layers):
|
| 368 |
-
# adjust stochastic depth probability based on the depth of the stage block
|
| 369 |
-
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
|
| 370 |
-
stage.append(block(cnf.input_channels, layer_scale, sd_prob))
|
| 371 |
-
stage_block_id += 1
|
| 372 |
-
layers.append(nn.Sequential(*stage))
|
| 373 |
-
if cnf.out_channels is not None:
|
| 374 |
-
# Downsampling
|
| 375 |
-
layers.append(
|
| 376 |
-
nn.Sequential(
|
| 377 |
-
norm_layer(cnf.input_channels),
|
| 378 |
-
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
|
| 379 |
-
)
|
| 380 |
-
)
|
| 381 |
-
|
| 382 |
-
self.features = nn.Sequential(*layers)
|
| 383 |
-
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
| 384 |
-
|
| 385 |
-
lastblock = block_setting[-1]
|
| 386 |
-
lastconv_output_channels = (
|
| 387 |
-
lastblock.out_channels
|
| 388 |
-
if lastblock.out_channels is not None
|
| 389 |
-
else lastblock.input_channels
|
| 390 |
-
)
|
| 391 |
-
self.classifier = nn.Sequential(
|
| 392 |
-
norm_layer(lastconv_output_channels),
|
| 393 |
-
nn.Flatten(1),
|
| 394 |
-
nn.Linear(lastconv_output_channels, out_features),
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
for m in self.modules():
|
| 398 |
-
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
| 399 |
-
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 400 |
-
if m.bias is not None:
|
| 401 |
-
nn.init.zeros_(m.bias)
|
| 402 |
-
|
| 403 |
-
def _forward_impl(self, x: Tensor) -> Tensor:
|
| 404 |
-
x = self.features(x)
|
| 405 |
-
x = self.avgpool(x)
|
| 406 |
-
x = self.classifier(x)
|
| 407 |
-
return x
|
| 408 |
-
|
| 409 |
-
def forward(self, x: Tensor) -> Tensor:
|
| 410 |
-
"""Run model forward"""
|
| 411 |
-
bs, s, c, h, w = x.shape
|
| 412 |
-
x = x.reshape((bs, s * c, h, w))
|
| 413 |
-
return self._forward_impl(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/encoders/encoders3d.py
DELETED
|
@@ -1,402 +0,0 @@
|
|
| 1 |
-
"""Encoder modules for the satellite/NWP data based on 3D concolutions.
|
| 2 |
-
"""
|
| 3 |
-
from typing import List, Union
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import nn
|
| 7 |
-
from torchvision.transforms import CenterCrop
|
| 8 |
-
|
| 9 |
-
from pvnet.models.multimodal.encoders.basic_blocks import (
|
| 10 |
-
AbstractNWPSatelliteEncoder,
|
| 11 |
-
ResidualConv3dBlock,
|
| 12 |
-
ResidualConv3dBlock2,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class DefaultPVNet(AbstractNWPSatelliteEncoder):
|
| 17 |
-
"""This is the original encoding module used in PVNet, with a few minor tweaks."""
|
| 18 |
-
|
| 19 |
-
def __init__(
|
| 20 |
-
self,
|
| 21 |
-
sequence_length: int,
|
| 22 |
-
image_size_pixels: int,
|
| 23 |
-
in_channels: int,
|
| 24 |
-
out_features: int,
|
| 25 |
-
number_of_conv3d_layers: int = 4,
|
| 26 |
-
conv3d_channels: int = 32,
|
| 27 |
-
fc_features: int = 128,
|
| 28 |
-
spatial_kernel_size: int = 3,
|
| 29 |
-
temporal_kernel_size: int = 3,
|
| 30 |
-
padding: Union[int, List[int]] = (1, 0, 0),
|
| 31 |
-
):
|
| 32 |
-
"""This is the original encoding module used in PVNet, with a few minor tweaks.
|
| 33 |
-
|
| 34 |
-
Args:
|
| 35 |
-
sequence_length: The time sequence length of the data.
|
| 36 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 37 |
-
in_channels: Number of input channels.
|
| 38 |
-
out_features: Number of output features.
|
| 39 |
-
number_of_conv3d_layers: Number of convolution 3d layers that are used.
|
| 40 |
-
conv3d_channels: Number of channels used in each conv3d layer.
|
| 41 |
-
fc_features: number of output nodes out of the hidden fully connected layer.
|
| 42 |
-
spatial_kernel_size: The spatial size of the kernel used in the conv3d layers.
|
| 43 |
-
temporal_kernel_size: The temporal size of the kernel used in the conv3d layers.
|
| 44 |
-
padding: The padding used in the conv3d layers. If an int, the same padding
|
| 45 |
-
is used in all dimensions
|
| 46 |
-
"""
|
| 47 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 48 |
-
if isinstance(padding, int):
|
| 49 |
-
padding = (padding, padding, padding)
|
| 50 |
-
# Check that the output shape of the convolutional layers will be at least 1x1
|
| 51 |
-
cnn_spatial_output_size = (
|
| 52 |
-
image_size_pixels
|
| 53 |
-
- ((spatial_kernel_size - 2 * padding[1]) - 1) * number_of_conv3d_layers
|
| 54 |
-
)
|
| 55 |
-
cnn_sequence_length = (
|
| 56 |
-
sequence_length
|
| 57 |
-
- ((temporal_kernel_size - 2 * padding[0]) - 1) * number_of_conv3d_layers
|
| 58 |
-
)
|
| 59 |
-
if not (cnn_spatial_output_size >= 1):
|
| 60 |
-
raise ValueError(
|
| 61 |
-
f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "
|
| 62 |
-
f"spatial size ({image_size_pixels})"
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
conv_layers = []
|
| 66 |
-
|
| 67 |
-
conv_layers += [
|
| 68 |
-
nn.Conv3d(
|
| 69 |
-
in_channels=in_channels,
|
| 70 |
-
out_channels=conv3d_channels,
|
| 71 |
-
kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size),
|
| 72 |
-
padding=padding,
|
| 73 |
-
),
|
| 74 |
-
nn.ELU(),
|
| 75 |
-
]
|
| 76 |
-
for i in range(0, number_of_conv3d_layers - 1):
|
| 77 |
-
conv_layers += [
|
| 78 |
-
nn.Conv3d(
|
| 79 |
-
in_channels=conv3d_channels,
|
| 80 |
-
out_channels=conv3d_channels,
|
| 81 |
-
kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size),
|
| 82 |
-
padding=padding,
|
| 83 |
-
),
|
| 84 |
-
nn.ELU(),
|
| 85 |
-
]
|
| 86 |
-
|
| 87 |
-
self.conv_layers = nn.Sequential(*conv_layers)
|
| 88 |
-
|
| 89 |
-
# Calculate the size of the output of the 3D convolutional layers
|
| 90 |
-
cnn_output_size = conv3d_channels * cnn_spatial_output_size**2 * cnn_sequence_length
|
| 91 |
-
|
| 92 |
-
self.final_block = nn.Sequential(
|
| 93 |
-
nn.Linear(in_features=cnn_output_size, out_features=fc_features),
|
| 94 |
-
nn.ELU(),
|
| 95 |
-
nn.Linear(in_features=fc_features, out_features=out_features),
|
| 96 |
-
nn.ELU(),
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
def forward(self, x):
|
| 100 |
-
"""Run model forward"""
|
| 101 |
-
out = self.conv_layers(x)
|
| 102 |
-
out = out.reshape(x.shape[0], -1)
|
| 103 |
-
|
| 104 |
-
# Fully connected layers
|
| 105 |
-
out = self.final_block(out)
|
| 106 |
-
|
| 107 |
-
return out
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class DefaultPVNet2(AbstractNWPSatelliteEncoder):
|
| 111 |
-
"""The original encoding module used in PVNet, with a few minor tweaks, and batchnorm."""
|
| 112 |
-
|
| 113 |
-
def __init__(
|
| 114 |
-
self,
|
| 115 |
-
sequence_length: int,
|
| 116 |
-
image_size_pixels: int,
|
| 117 |
-
in_channels: int,
|
| 118 |
-
out_features: int,
|
| 119 |
-
number_of_conv3d_layers: int = 4,
|
| 120 |
-
conv3d_channels: int = 32,
|
| 121 |
-
fc_features: int = 128,
|
| 122 |
-
batch_norm=True,
|
| 123 |
-
fc_dropout=0.2,
|
| 124 |
-
):
|
| 125 |
-
"""The original encoding module used in PVNet, with a few minor tweaks, and batchnorm.
|
| 126 |
-
|
| 127 |
-
Args:
|
| 128 |
-
sequence_length: The time sequence length of the data.
|
| 129 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 130 |
-
in_channels: Number of input channels.
|
| 131 |
-
out_features: Number of output features.
|
| 132 |
-
number_of_conv3d_layers: Number of convolution 3d layers that are used.
|
| 133 |
-
conv3d_channels: Number of channels used in each conv3d layer.
|
| 134 |
-
fc_features: number of output nodes out of the hidden fully connected layer.
|
| 135 |
-
batch_norm: Whether to include 3D batch normalisation.
|
| 136 |
-
fc_dropout: Probability of an element to be zeroed before the last two fully connected
|
| 137 |
-
layers.
|
| 138 |
-
"""
|
| 139 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 140 |
-
|
| 141 |
-
# Check that the output shape of the convolutional layers will be at least 1x1
|
| 142 |
-
cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv3d_layers
|
| 143 |
-
if not (cnn_spatial_output_size > 0):
|
| 144 |
-
raise ValueError(
|
| 145 |
-
f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "
|
| 146 |
-
f"spatial size ({image_size_pixels})"
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
conv_layers = [
|
| 150 |
-
nn.Conv3d(
|
| 151 |
-
in_channels=in_channels,
|
| 152 |
-
out_channels=conv3d_channels,
|
| 153 |
-
kernel_size=(3, 3, 3),
|
| 154 |
-
padding=(1, 0, 0),
|
| 155 |
-
),
|
| 156 |
-
nn.LeakyReLU(),
|
| 157 |
-
]
|
| 158 |
-
if batch_norm:
|
| 159 |
-
# Inserted before activation using position -1
|
| 160 |
-
conv_layers.insert(-1, nn.BatchNorm3d(conv3d_channels))
|
| 161 |
-
for i in range(0, number_of_conv3d_layers - 1):
|
| 162 |
-
conv_layers += [
|
| 163 |
-
nn.Conv3d(
|
| 164 |
-
in_channels=conv3d_channels,
|
| 165 |
-
out_channels=conv3d_channels,
|
| 166 |
-
kernel_size=(3, 3, 3),
|
| 167 |
-
padding=(1, 0, 0),
|
| 168 |
-
),
|
| 169 |
-
nn.LeakyReLU(),
|
| 170 |
-
]
|
| 171 |
-
if batch_norm:
|
| 172 |
-
# Inserted before activation using position -1
|
| 173 |
-
conv_layers.insert(-1, nn.BatchNorm3d(conv3d_channels))
|
| 174 |
-
|
| 175 |
-
self.conv_layers = nn.Sequential(*conv_layers)
|
| 176 |
-
|
| 177 |
-
# Calculate the size of the output of the 3D convolutional layers
|
| 178 |
-
cnn_output_size = conv3d_channels * cnn_spatial_output_size**2 * sequence_length
|
| 179 |
-
|
| 180 |
-
final_block = [
|
| 181 |
-
nn.Linear(in_features=cnn_output_size, out_features=fc_features),
|
| 182 |
-
nn.LeakyReLU(),
|
| 183 |
-
nn.Linear(in_features=fc_features, out_features=out_features),
|
| 184 |
-
nn.LeakyReLU(),
|
| 185 |
-
]
|
| 186 |
-
|
| 187 |
-
if fc_dropout > 0:
|
| 188 |
-
# Insert after the linear layers
|
| 189 |
-
final_block.insert(1, nn.Dropout(fc_dropout))
|
| 190 |
-
final_block.insert(-1, nn.Dropout(fc_dropout))
|
| 191 |
-
|
| 192 |
-
self.final_block = nn.Sequential(*final_block)
|
| 193 |
-
|
| 194 |
-
def forward(self, x):
|
| 195 |
-
"""Run model forward"""
|
| 196 |
-
out = self.conv_layers(x)
|
| 197 |
-
out = out.reshape(x.shape[0], -1)
|
| 198 |
-
|
| 199 |
-
# Fully connected layers
|
| 200 |
-
out = self.final_block(out)
|
| 201 |
-
|
| 202 |
-
return out
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
class ResConv3DNet2(AbstractNWPSatelliteEncoder):
|
| 206 |
-
"""3D convolutional network based on ResNet architecture.
|
| 207 |
-
|
| 208 |
-
The residual blocks are implemented based on the best performing block in [1].
|
| 209 |
-
|
| 210 |
-
Sources:
|
| 211 |
-
[1] https://arxiv.org/pdf/1603.05027.pdf
|
| 212 |
-
"""
|
| 213 |
-
|
| 214 |
-
def __init__(
|
| 215 |
-
self,
|
| 216 |
-
sequence_length: int,
|
| 217 |
-
image_size_pixels: int,
|
| 218 |
-
in_channels: int,
|
| 219 |
-
out_features: int,
|
| 220 |
-
hidden_channels: int = 32,
|
| 221 |
-
n_res_blocks: int = 4,
|
| 222 |
-
res_block_layers: int = 2,
|
| 223 |
-
batch_norm=True,
|
| 224 |
-
dropout_frac=0.0,
|
| 225 |
-
):
|
| 226 |
-
"""Fully connected deep network based on ResNet architecture.
|
| 227 |
-
|
| 228 |
-
Args:
|
| 229 |
-
sequence_length: The time sequence length of the data.
|
| 230 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 231 |
-
in_channels: Number of input channels.
|
| 232 |
-
out_features: Number of output features.
|
| 233 |
-
hidden_channels: Number of channels in middle hidden layers.
|
| 234 |
-
n_res_blocks: Number of residual blocks to use.
|
| 235 |
-
res_block_layers: Number of Conv3D layers used in each residual block.
|
| 236 |
-
batch_norm: Whether to include batch normalisation.
|
| 237 |
-
dropout_frac: Probability of an element to be zeroed in the residual pathways.
|
| 238 |
-
"""
|
| 239 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 240 |
-
|
| 241 |
-
model = [
|
| 242 |
-
nn.Conv3d(
|
| 243 |
-
in_channels=in_channels,
|
| 244 |
-
out_channels=hidden_channels,
|
| 245 |
-
kernel_size=(3, 3, 3),
|
| 246 |
-
padding=(1, 1, 1),
|
| 247 |
-
),
|
| 248 |
-
]
|
| 249 |
-
|
| 250 |
-
for i in range(n_res_blocks):
|
| 251 |
-
model.extend(
|
| 252 |
-
[
|
| 253 |
-
ResidualConv3dBlock2(
|
| 254 |
-
in_channels=hidden_channels,
|
| 255 |
-
n_layers=res_block_layers,
|
| 256 |
-
dropout_frac=dropout_frac,
|
| 257 |
-
batch_norm=batch_norm,
|
| 258 |
-
),
|
| 259 |
-
nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2)),
|
| 260 |
-
]
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
# Calculate the size of the output of the 3D convolutional layers
|
| 264 |
-
final_im_size = image_size_pixels // (2**n_res_blocks)
|
| 265 |
-
cnn_output_size = hidden_channels * sequence_length * final_im_size * final_im_size
|
| 266 |
-
|
| 267 |
-
model.extend(
|
| 268 |
-
[
|
| 269 |
-
nn.ELU(),
|
| 270 |
-
nn.Flatten(start_dim=1, end_dim=-1),
|
| 271 |
-
nn.Linear(in_features=cnn_output_size, out_features=out_features),
|
| 272 |
-
nn.ELU(),
|
| 273 |
-
]
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
self.model = nn.Sequential(*model)
|
| 277 |
-
|
| 278 |
-
def forward(self, x):
|
| 279 |
-
"""Run model forward"""
|
| 280 |
-
return self.model(x)
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
class EncoderUNET(AbstractNWPSatelliteEncoder):
|
| 284 |
-
"""An encoder based on emodifed UNet architecture.
|
| 285 |
-
|
| 286 |
-
An encoder for satellite and/or NWP data taking inspiration from the kinds of skip
|
| 287 |
-
connections in UNet. This differs from an actual UNet in that it does not have upsampling
|
| 288 |
-
layers, instead it concats features from different spatial scales, and applies a few extra
|
| 289 |
-
conv3d layers.
|
| 290 |
-
"""
|
| 291 |
-
|
| 292 |
-
def __init__(
|
| 293 |
-
self,
|
| 294 |
-
sequence_length: int,
|
| 295 |
-
image_size_pixels: int,
|
| 296 |
-
in_channels: int,
|
| 297 |
-
out_features: int,
|
| 298 |
-
n_downscale: int = 3,
|
| 299 |
-
res_block_layers: int = 2,
|
| 300 |
-
conv3d_channels: int = 32,
|
| 301 |
-
dropout_frac: float = 0.1,
|
| 302 |
-
):
|
| 303 |
-
"""An encoder based on emodifed UNet architecture.
|
| 304 |
-
|
| 305 |
-
Args:
|
| 306 |
-
sequence_length: The time sequence length of the data.
|
| 307 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 308 |
-
in_channels: Number of input channels.
|
| 309 |
-
out_features: Number of output features.
|
| 310 |
-
n_downscale: Number of conv3d and spatially downscaling layers that are used.
|
| 311 |
-
res_block_layers: Number of residual blocks used after each downscale layer.
|
| 312 |
-
conv3d_channels: Number of channels used in each conv3d layer.
|
| 313 |
-
dropout_frac: Probability of an element to be zeroed in the residual pathways.
|
| 314 |
-
"""
|
| 315 |
-
cnn_spatial_output = image_size_pixels // (2**n_downscale)
|
| 316 |
-
|
| 317 |
-
if not (cnn_spatial_output > 0):
|
| 318 |
-
raise ValueError(
|
| 319 |
-
f"cannot use this many downscaling layers ({n_downscale}) with this input "
|
| 320 |
-
f"spatial size ({image_size_pixels})"
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 324 |
-
|
| 325 |
-
self.first_layer = nn.Sequential(
|
| 326 |
-
nn.Conv3d(
|
| 327 |
-
in_channels=in_channels,
|
| 328 |
-
out_channels=conv3d_channels,
|
| 329 |
-
kernel_size=(1, 1, 1),
|
| 330 |
-
padding=(0, 0, 0),
|
| 331 |
-
),
|
| 332 |
-
ResidualConv3dBlock(
|
| 333 |
-
in_channels=conv3d_channels,
|
| 334 |
-
n_layers=res_block_layers,
|
| 335 |
-
dropout_frac=dropout_frac,
|
| 336 |
-
),
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
downscale_layers = []
|
| 340 |
-
for _ in range(n_downscale):
|
| 341 |
-
downscale_layers += [
|
| 342 |
-
nn.Sequential(
|
| 343 |
-
ResidualConv3dBlock(
|
| 344 |
-
in_channels=conv3d_channels,
|
| 345 |
-
n_layers=res_block_layers,
|
| 346 |
-
dropout_frac=dropout_frac,
|
| 347 |
-
),
|
| 348 |
-
nn.ELU(),
|
| 349 |
-
nn.Conv3d(
|
| 350 |
-
in_channels=conv3d_channels,
|
| 351 |
-
out_channels=conv3d_channels,
|
| 352 |
-
kernel_size=(1, 2, 2),
|
| 353 |
-
padding=(0, 0, 0),
|
| 354 |
-
stride=(1, 2, 2),
|
| 355 |
-
),
|
| 356 |
-
)
|
| 357 |
-
]
|
| 358 |
-
|
| 359 |
-
self.downscale_layers = nn.ModuleList(downscale_layers)
|
| 360 |
-
|
| 361 |
-
self.crop_fn = CenterCrop(cnn_spatial_output)
|
| 362 |
-
|
| 363 |
-
cat_channels = conv3d_channels * (1 + n_downscale)
|
| 364 |
-
self.post_cat_conv = nn.Sequential(
|
| 365 |
-
ResidualConv3dBlock(
|
| 366 |
-
in_channels=cat_channels,
|
| 367 |
-
n_layers=res_block_layers,
|
| 368 |
-
),
|
| 369 |
-
nn.ELU(),
|
| 370 |
-
nn.Conv3d(
|
| 371 |
-
in_channels=cat_channels,
|
| 372 |
-
out_channels=conv3d_channels,
|
| 373 |
-
kernel_size=(1, 1, 1),
|
| 374 |
-
),
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
final_channels = (
|
| 378 |
-
(image_size_pixels // (2**n_downscale)) ** 2 * conv3d_channels * sequence_length
|
| 379 |
-
)
|
| 380 |
-
self.final_layer = nn.Sequential(
|
| 381 |
-
nn.ELU(),
|
| 382 |
-
nn.Linear(
|
| 383 |
-
in_features=final_channels,
|
| 384 |
-
out_features=out_features,
|
| 385 |
-
),
|
| 386 |
-
nn.ELU(),
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
def forward(self, x):
|
| 390 |
-
"""Run model forward"""
|
| 391 |
-
out = self.first_layer(x)
|
| 392 |
-
outputs = [self.crop_fn(out)]
|
| 393 |
-
|
| 394 |
-
for layer in self.downscale_layers:
|
| 395 |
-
out = layer(out)
|
| 396 |
-
outputs += [self.crop_fn(out)]
|
| 397 |
-
|
| 398 |
-
out = torch.cat(outputs, dim=1)
|
| 399 |
-
out = self.post_cat_conv(out)
|
| 400 |
-
out = torch.flatten(out, start_dim=1)
|
| 401 |
-
out = self.final_layer(out)
|
| 402 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/encoders/encodersRNN.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
"""Encoder modules for the satellite/NWP data based on recursive and 2D convolutional layers.
|
| 2 |
-
"""
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from pvnet.models.multimodal.encoders.basic_blocks import (
|
| 8 |
-
AbstractNWPSatelliteEncoder,
|
| 9 |
-
ImageSequenceEncoder,
|
| 10 |
-
)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class ConvLSTM(AbstractNWPSatelliteEncoder):
|
| 14 |
-
"""Convolutional LSTM block from MetNet."""
|
| 15 |
-
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
sequence_length: int,
|
| 19 |
-
image_size_pixels: int,
|
| 20 |
-
in_channels: int,
|
| 21 |
-
out_features: int,
|
| 22 |
-
hidden_channels: int = 32,
|
| 23 |
-
num_layers: int = 2,
|
| 24 |
-
kernel_size: int = 3,
|
| 25 |
-
bias: bool = True,
|
| 26 |
-
activation=torch.tanh,
|
| 27 |
-
batchnorm=False,
|
| 28 |
-
):
|
| 29 |
-
"""Convolutional LSTM block from MetNet.
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
sequence_length: The time sequence length of the data.
|
| 33 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 34 |
-
in_channels: Number of input channels.
|
| 35 |
-
out_features: Number of output features.
|
| 36 |
-
hidden_channels: Hidden dimension size.
|
| 37 |
-
num_layers: Depth of ConvLSTM cells.
|
| 38 |
-
kernel_size: Kernel size.
|
| 39 |
-
bias: Whether to add bias.
|
| 40 |
-
activation: Activation function for ConvLSTM cells.
|
| 41 |
-
batchnorm: Whether to use batch norm.
|
| 42 |
-
"""
|
| 43 |
-
from metnet.layers.ConvLSTM import ConvLSTM as _ConvLSTM
|
| 44 |
-
|
| 45 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 46 |
-
|
| 47 |
-
self.conv_lstm = _ConvLSTM(
|
| 48 |
-
input_dim=in_channels,
|
| 49 |
-
hidden_dim=hidden_channels,
|
| 50 |
-
kernel_size=kernel_size,
|
| 51 |
-
num_layers=num_layers,
|
| 52 |
-
bias=bias,
|
| 53 |
-
activation=activation,
|
| 54 |
-
batchnorm=batchnorm,
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
# Calculate the size of the output of the ConvLSTM network
|
| 58 |
-
convlstm_output_size = hidden_channels * image_size_pixels**2
|
| 59 |
-
|
| 60 |
-
self.final_block = nn.Sequential(
|
| 61 |
-
nn.Linear(in_features=convlstm_output_size, out_features=out_features),
|
| 62 |
-
nn.ELU(),
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
def forward(self, x):
|
| 66 |
-
"""Run model forward"""
|
| 67 |
-
|
| 68 |
-
batch_size, channel, seq_len, height, width = x.shape
|
| 69 |
-
x = torch.swapaxes(x, 1, 2)
|
| 70 |
-
|
| 71 |
-
res, _ = self.conv_lstm(x)
|
| 72 |
-
|
| 73 |
-
# Select last state only
|
| 74 |
-
out = res[:, -1]
|
| 75 |
-
|
| 76 |
-
# Flatten and fully connected layer
|
| 77 |
-
out = out.reshape(batch_size, -1)
|
| 78 |
-
out = self.final_block(out)
|
| 79 |
-
|
| 80 |
-
return out
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
class FlattenLSTM(AbstractNWPSatelliteEncoder):
|
| 84 |
-
"""Convolutional blocks followed by LSTM."""
|
| 85 |
-
|
| 86 |
-
def __init__(
|
| 87 |
-
self,
|
| 88 |
-
sequence_length: int,
|
| 89 |
-
image_size_pixels: int,
|
| 90 |
-
in_channels: int,
|
| 91 |
-
out_features: int,
|
| 92 |
-
num_layers: int = 2,
|
| 93 |
-
number_of_conv2d_layers: int = 4,
|
| 94 |
-
conv2d_channels: int = 32,
|
| 95 |
-
):
|
| 96 |
-
"""Network consisting of 2D spatial convolutional and LSTM sequence encoder.
|
| 97 |
-
|
| 98 |
-
Args:
|
| 99 |
-
sequence_length: The time sequence length of the data.
|
| 100 |
-
image_size_pixels: The spatial size of the image. Assumed square.
|
| 101 |
-
in_channels: Number of input channels.
|
| 102 |
-
out_features: Number of output features. Also used for LSTM hidden dimension.
|
| 103 |
-
num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking
|
| 104 |
-
two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of
|
| 105 |
-
the first LSTM and computing the final results.
|
| 106 |
-
number_of_conv2d_layers: Number of convolution 2D layers that are used.
|
| 107 |
-
conv2d_channels: Number of channels used in each conv2d layer.
|
| 108 |
-
"""
|
| 109 |
-
|
| 110 |
-
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
|
| 111 |
-
|
| 112 |
-
self.lstm = nn.LSTM(
|
| 113 |
-
input_size=out_features,
|
| 114 |
-
hidden_size=out_features,
|
| 115 |
-
num_layers=num_layers,
|
| 116 |
-
batch_first=True,
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
self.encode_image_sequence = ImageSequenceEncoder(
|
| 120 |
-
image_size_pixels=image_size_pixels,
|
| 121 |
-
in_channels=in_channels,
|
| 122 |
-
number_of_conv2d_layers=number_of_conv2d_layers,
|
| 123 |
-
conv2d_channels=conv2d_channels,
|
| 124 |
-
fc_features=out_features,
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
self.final_block = nn.Sequential(
|
| 128 |
-
nn.Linear(in_features=out_features, out_features=out_features),
|
| 129 |
-
nn.ELU(),
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
def forward(self, x):
|
| 133 |
-
"""Run model forward"""
|
| 134 |
-
encoded_images = self.encode_image_sequence(x)
|
| 135 |
-
|
| 136 |
-
_, (_, c_n) = self.lstm(encoded_images)
|
| 137 |
-
|
| 138 |
-
# Take only the deepest level hidden cell state
|
| 139 |
-
out = self.final_block(c_n[-1])
|
| 140 |
-
|
| 141 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/linear_networks/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Submodels to combine 1D feature vectors from different sources and make final predictions"""
|
|
|
|
|
|
pvnet/models/multimodal/linear_networks/basic_blocks.py
DELETED
|
@@ -1,121 +0,0 @@
|
|
| 1 |
-
"""Basic blocks for the lienar networks"""
|
| 2 |
-
from abc import ABCMeta, abstractmethod
|
| 3 |
-
from collections import OrderedDict
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import nn
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class AbstractLinearNetwork(nn.Module, metaclass=ABCMeta):
|
| 10 |
-
"""Abstract class for a network to combine the features from all the inputs."""
|
| 11 |
-
|
| 12 |
-
def __init__(
|
| 13 |
-
self,
|
| 14 |
-
in_features: int,
|
| 15 |
-
out_features: int,
|
| 16 |
-
):
|
| 17 |
-
"""Abstract class for a network to combine the features from all the inputs.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
in_features: Number of input features.
|
| 21 |
-
out_features: Number of output features.
|
| 22 |
-
"""
|
| 23 |
-
super().__init__()
|
| 24 |
-
|
| 25 |
-
def cat_modes(self, x):
|
| 26 |
-
"""Concatenate modes of input data into 1D feature vector"""
|
| 27 |
-
if isinstance(x, OrderedDict):
|
| 28 |
-
return torch.cat([value for key, value in x.items()], dim=1)
|
| 29 |
-
elif isinstance(x, torch.Tensor):
|
| 30 |
-
return x
|
| 31 |
-
else:
|
| 32 |
-
raise ValueError(f"Input of unexpected type {type(x)}")
|
| 33 |
-
|
| 34 |
-
@abstractmethod
|
| 35 |
-
def forward(self):
|
| 36 |
-
"""Run model forward"""
|
| 37 |
-
pass
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class ResidualLinearBlock(nn.Module):
|
| 41 |
-
"""A 1D fully-connected residual block using ELU activations and including optional dropout."""
|
| 42 |
-
|
| 43 |
-
def __init__(
|
| 44 |
-
self,
|
| 45 |
-
in_features: int,
|
| 46 |
-
n_layers: int = 2,
|
| 47 |
-
dropout_frac: float = 0.0,
|
| 48 |
-
):
|
| 49 |
-
"""A 1D fully-connected residual block using ELU activations and including optional dropout.
|
| 50 |
-
|
| 51 |
-
Args:
|
| 52 |
-
in_features: Number of input features.
|
| 53 |
-
n_layers: Number of layers in residual pathway.
|
| 54 |
-
dropout_frac: Probability of an element to be zeroed.
|
| 55 |
-
"""
|
| 56 |
-
super().__init__()
|
| 57 |
-
|
| 58 |
-
layers = []
|
| 59 |
-
for i in range(n_layers):
|
| 60 |
-
layers += [
|
| 61 |
-
nn.ELU(),
|
| 62 |
-
nn.Linear(
|
| 63 |
-
in_features=in_features,
|
| 64 |
-
out_features=in_features,
|
| 65 |
-
),
|
| 66 |
-
nn.Dropout(p=dropout_frac),
|
| 67 |
-
]
|
| 68 |
-
self.model = nn.Sequential(*layers)
|
| 69 |
-
|
| 70 |
-
def forward(self, x):
|
| 71 |
-
"""Run model forward"""
|
| 72 |
-
return self.model(x) + x
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class ResidualLinearBlock2(nn.Module):
|
| 76 |
-
"""Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
|
| 77 |
-
|
| 78 |
-
This was the best performing residual block tested in the study. This implementation differs
|
| 79 |
-
from that block just by using LeakyReLU activation to avoid dead neuron, and by including
|
| 80 |
-
optional dropout in the residual branch. This is also a 1D fully connected layer residual block
|
| 81 |
-
rather than a 2D convolutional block.
|
| 82 |
-
|
| 83 |
-
Sources:
|
| 84 |
-
[1] https://arxiv.org/pdf/1603.05027.pdf
|
| 85 |
-
"""
|
| 86 |
-
|
| 87 |
-
def __init__(
|
| 88 |
-
self,
|
| 89 |
-
in_features: int,
|
| 90 |
-
n_layers: int = 2,
|
| 91 |
-
dropout_frac: float = 0.0,
|
| 92 |
-
):
|
| 93 |
-
"""Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
|
| 94 |
-
|
| 95 |
-
Sources:
|
| 96 |
-
[1] https://arxiv.org/pdf/1603.05027.pdf
|
| 97 |
-
|
| 98 |
-
Args:
|
| 99 |
-
in_features: Number of input features.
|
| 100 |
-
n_layers: Number of layers in residual pathway.
|
| 101 |
-
dropout_frac: Probability of an element to be zeroed.
|
| 102 |
-
"""
|
| 103 |
-
super().__init__()
|
| 104 |
-
|
| 105 |
-
layers = []
|
| 106 |
-
for i in range(n_layers):
|
| 107 |
-
layers += [
|
| 108 |
-
nn.BatchNorm1d(in_features),
|
| 109 |
-
nn.Dropout(p=dropout_frac),
|
| 110 |
-
nn.LeakyReLU(),
|
| 111 |
-
nn.Linear(
|
| 112 |
-
in_features=in_features,
|
| 113 |
-
out_features=in_features,
|
| 114 |
-
),
|
| 115 |
-
]
|
| 116 |
-
|
| 117 |
-
self.model = nn.Sequential(*layers)
|
| 118 |
-
|
| 119 |
-
def forward(self, x):
|
| 120 |
-
"""Run model forward"""
|
| 121 |
-
return self.model(x) + x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/linear_networks/networks.py
DELETED
|
@@ -1,332 +0,0 @@
|
|
| 1 |
-
"""Linear networks used for the fusion model"""
|
| 2 |
-
from torch import nn, rand
|
| 3 |
-
|
| 4 |
-
from pvnet.models.multimodal.linear_networks.basic_blocks import (
|
| 5 |
-
AbstractLinearNetwork,
|
| 6 |
-
ResidualLinearBlock,
|
| 7 |
-
ResidualLinearBlock2,
|
| 8 |
-
)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class DefaultFCNet(AbstractLinearNetwork):
|
| 12 |
-
"""Similar to the original FCNet module used in PVNet, with a few minor tweaks.
|
| 13 |
-
|
| 14 |
-
This is a 2-layer fully connected block, with internal ELU activations and output ReLU.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
in_features: int,
|
| 20 |
-
out_features: int,
|
| 21 |
-
fc_hidden_features: int = 128,
|
| 22 |
-
):
|
| 23 |
-
"""Similar to the original FCNet module used in PVNet, with a few minor tweaks.
|
| 24 |
-
|
| 25 |
-
Args:
|
| 26 |
-
in_features: Number of input features.
|
| 27 |
-
out_features: Number of output features.
|
| 28 |
-
fc_hidden_features: Number of features in middle hidden layer.
|
| 29 |
-
"""
|
| 30 |
-
super().__init__(in_features, out_features)
|
| 31 |
-
|
| 32 |
-
self.model = nn.Sequential(
|
| 33 |
-
nn.Linear(in_features=in_features, out_features=fc_hidden_features),
|
| 34 |
-
nn.ELU(),
|
| 35 |
-
nn.Linear(in_features=fc_hidden_features, out_features=out_features),
|
| 36 |
-
nn.ReLU(),
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
def forward(self, x):
|
| 40 |
-
"""Run model forward"""
|
| 41 |
-
x = self.cat_modes(x)
|
| 42 |
-
return self.model(x)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class ResFCNet(AbstractLinearNetwork):
|
| 46 |
-
"""Fully-connected deep network based on ResNet architecture.
|
| 47 |
-
|
| 48 |
-
Internally, this network uses ELU activations throughout the residual blocks.
|
| 49 |
-
With n_res_blocks=0 this becomes equivalent to `DefaultFCNet`.
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
def __init__(
|
| 53 |
-
self,
|
| 54 |
-
in_features: int,
|
| 55 |
-
out_features: int,
|
| 56 |
-
fc_hidden_features: int = 128,
|
| 57 |
-
n_res_blocks: int = 4,
|
| 58 |
-
res_block_layers: int = 2,
|
| 59 |
-
dropout_frac: float = 0.2,
|
| 60 |
-
):
|
| 61 |
-
"""Fully-connected deep network based on ResNet architecture.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
in_features: Number of input features.
|
| 65 |
-
out_features: Number of output features.
|
| 66 |
-
fc_hidden_features: Number of features in middle hidden layers.
|
| 67 |
-
n_res_blocks: Number of residual blocks to use.
|
| 68 |
-
res_block_layers: Number of fully-connected layers used in each residual block.
|
| 69 |
-
dropout_frac: Probability of an element to be zeroed in the residual pathways.
|
| 70 |
-
"""
|
| 71 |
-
super().__init__(in_features, out_features)
|
| 72 |
-
|
| 73 |
-
model = [
|
| 74 |
-
nn.Linear(in_features=in_features, out_features=fc_hidden_features),
|
| 75 |
-
]
|
| 76 |
-
|
| 77 |
-
for i in range(n_res_blocks):
|
| 78 |
-
model += [
|
| 79 |
-
ResidualLinearBlock(
|
| 80 |
-
in_features=fc_hidden_features,
|
| 81 |
-
n_layers=res_block_layers,
|
| 82 |
-
dropout_frac=dropout_frac,
|
| 83 |
-
)
|
| 84 |
-
]
|
| 85 |
-
|
| 86 |
-
model += [
|
| 87 |
-
nn.ELU(),
|
| 88 |
-
nn.Linear(in_features=fc_hidden_features, out_features=out_features),
|
| 89 |
-
nn.LeakyReLU(negative_slope=0.01),
|
| 90 |
-
]
|
| 91 |
-
self.model = nn.Sequential(*model)
|
| 92 |
-
|
| 93 |
-
def forward(self, x):
|
| 94 |
-
"""Run model forward"""
|
| 95 |
-
x = self.cat_modes(x)
|
| 96 |
-
return self.model(x)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
class ResFCNet2(AbstractLinearNetwork):
|
| 100 |
-
"""Fully connected deep network based on ResNet architecture.
|
| 101 |
-
|
| 102 |
-
This architecture is similar to
|
| 103 |
-
`ResFCNet`, except that it uses LeakyReLU activations internally, and batchnorm in the residual
|
| 104 |
-
branches. The residual blocks are implemented based on the best performing block in [1].
|
| 105 |
-
|
| 106 |
-
Sources:
|
| 107 |
-
[1] https://arxiv.org/pdf/1603.05027.pdf
|
| 108 |
-
"""
|
| 109 |
-
|
| 110 |
-
def __init__(
|
| 111 |
-
self,
|
| 112 |
-
in_features: int,
|
| 113 |
-
out_features: int,
|
| 114 |
-
fc_hidden_features: int = 128,
|
| 115 |
-
n_res_blocks: int = 4,
|
| 116 |
-
res_block_layers: int = 2,
|
| 117 |
-
dropout_frac=0.0,
|
| 118 |
-
):
|
| 119 |
-
"""Fully connected deep network based on ResNet architecture.
|
| 120 |
-
|
| 121 |
-
Args:
|
| 122 |
-
in_features: Number of input features.
|
| 123 |
-
out_features: Number of output features.
|
| 124 |
-
fc_hidden_features: Number of features in middle hidden layers.
|
| 125 |
-
n_res_blocks: Number of residual blocks to use.
|
| 126 |
-
res_block_layers: Number of fully-connected layers used in each residual block.
|
| 127 |
-
dropout_frac: Probability of an element to be zeroed in the residual pathways.
|
| 128 |
-
"""
|
| 129 |
-
super().__init__(in_features, out_features)
|
| 130 |
-
|
| 131 |
-
model = [
|
| 132 |
-
nn.Linear(in_features=in_features, out_features=fc_hidden_features),
|
| 133 |
-
]
|
| 134 |
-
|
| 135 |
-
for i in range(n_res_blocks):
|
| 136 |
-
model += [
|
| 137 |
-
ResidualLinearBlock2(
|
| 138 |
-
in_features=fc_hidden_features,
|
| 139 |
-
n_layers=res_block_layers,
|
| 140 |
-
dropout_frac=dropout_frac,
|
| 141 |
-
)
|
| 142 |
-
]
|
| 143 |
-
|
| 144 |
-
model += [
|
| 145 |
-
nn.LeakyReLU(),
|
| 146 |
-
nn.Linear(in_features=fc_hidden_features, out_features=out_features),
|
| 147 |
-
nn.LeakyReLU(negative_slope=0.01),
|
| 148 |
-
]
|
| 149 |
-
|
| 150 |
-
self.model = nn.Sequential(*model)
|
| 151 |
-
|
| 152 |
-
def forward(self, x):
|
| 153 |
-
"""Run model forward"""
|
| 154 |
-
x = self.cat_modes(x)
|
| 155 |
-
return self.model(x)
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
class SNN(AbstractLinearNetwork):
|
| 159 |
-
"""Self normalising neural network implementation borrowed from [1] and proposed in [2].
|
| 160 |
-
|
| 161 |
-
Sources:
|
| 162 |
-
[1] https://github.com/tonyduan/snn/blob/master/snn/models.py
|
| 163 |
-
[2] https://arxiv.org/pdf/1706.02515v5.pdf
|
| 164 |
-
|
| 165 |
-
Args:
|
| 166 |
-
in_features: Number of input features.
|
| 167 |
-
out_features: Number of output features.
|
| 168 |
-
fc_hidden_features: Number of features in middle hidden layers.
|
| 169 |
-
n_layers: Number of fully-connected layers used in the network.
|
| 170 |
-
dropout_frac: Probability of an element to be zeroed.
|
| 171 |
-
|
| 172 |
-
"""
|
| 173 |
-
|
| 174 |
-
def __init__(
|
| 175 |
-
self,
|
| 176 |
-
in_features: int,
|
| 177 |
-
out_features: int,
|
| 178 |
-
fc_hidden_features: int = 128,
|
| 179 |
-
n_layers: int = 10,
|
| 180 |
-
dropout_frac: float = 0.0,
|
| 181 |
-
):
|
| 182 |
-
"""Self normalising neural network implementation borrowed from [1] and proposed in [2].
|
| 183 |
-
|
| 184 |
-
Sources:
|
| 185 |
-
[1] https://github.com/tonyduan/snn/blob/master/snn/models.py
|
| 186 |
-
[2] https://arxiv.org/pdf/1706.02515v5.pdf
|
| 187 |
-
|
| 188 |
-
Args:
|
| 189 |
-
in_features: Number of input features.
|
| 190 |
-
out_features: Number of output features.
|
| 191 |
-
fc_hidden_features: Number of features in middle hidden layers.
|
| 192 |
-
n_layers: Number of fully-connected layers used in the network.
|
| 193 |
-
dropout_frac: Probability of an element to be zeroed.
|
| 194 |
-
|
| 195 |
-
"""
|
| 196 |
-
super().__init__(in_features, out_features)
|
| 197 |
-
|
| 198 |
-
layers = [
|
| 199 |
-
nn.Linear(in_features, fc_hidden_features, bias=False),
|
| 200 |
-
nn.SELU(),
|
| 201 |
-
nn.AlphaDropout(p=dropout_frac),
|
| 202 |
-
]
|
| 203 |
-
for i in range(1, n_layers - 1):
|
| 204 |
-
layers += [
|
| 205 |
-
nn.Linear(fc_hidden_features, fc_hidden_features, bias=False),
|
| 206 |
-
nn.SELU(),
|
| 207 |
-
nn.AlphaDropout(p=dropout_frac),
|
| 208 |
-
]
|
| 209 |
-
layers += [
|
| 210 |
-
nn.Linear(fc_hidden_features, out_features, bias=True),
|
| 211 |
-
nn.LeakyReLU(negative_slope=0.01),
|
| 212 |
-
]
|
| 213 |
-
|
| 214 |
-
self.network = nn.Sequential(*layers)
|
| 215 |
-
self._reset_parameters()
|
| 216 |
-
|
| 217 |
-
def forward(self, x):
|
| 218 |
-
"""Run model forward"""
|
| 219 |
-
x = self.cat_modes(x)
|
| 220 |
-
return self.network(x)
|
| 221 |
-
|
| 222 |
-
def _reset_parameters(self):
|
| 223 |
-
for layer in self.network:
|
| 224 |
-
if isinstance(layer, nn.Linear):
|
| 225 |
-
nn.init.normal_(layer.weight, std=layer.out_features**-0.5)
|
| 226 |
-
if layer.bias is not None:
|
| 227 |
-
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
|
| 228 |
-
bound = fan_in**-0.5
|
| 229 |
-
nn.init.uniform_(layer.bias, -bound, bound)
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
class TabNet(AbstractLinearNetwork):
|
| 233 |
-
"""An implmentation of TabNet [1].
|
| 234 |
-
|
| 235 |
-
The implementation comes rom `pytorch_tabnet` and this must be installed for use.
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
Sources:
|
| 239 |
-
[1] https://arxiv.org/abs/1908.07442
|
| 240 |
-
"""
|
| 241 |
-
|
| 242 |
-
def __init__(
|
| 243 |
-
self,
|
| 244 |
-
in_features: int,
|
| 245 |
-
out_features: int,
|
| 246 |
-
n_d=8,
|
| 247 |
-
n_a=8,
|
| 248 |
-
n_steps=3,
|
| 249 |
-
gamma=1.3,
|
| 250 |
-
cat_idxs=[],
|
| 251 |
-
cat_dims=[],
|
| 252 |
-
cat_emb_dim=1,
|
| 253 |
-
n_independent=2,
|
| 254 |
-
n_shared=2,
|
| 255 |
-
epsilon=1e-15,
|
| 256 |
-
virtual_batch_size=128,
|
| 257 |
-
momentum=0.02,
|
| 258 |
-
mask_type="sparsemax",
|
| 259 |
-
):
|
| 260 |
-
"""An implmentation of TabNet [1].
|
| 261 |
-
|
| 262 |
-
Sources:
|
| 263 |
-
[1] https://arxiv.org/abs/1908.07442
|
| 264 |
-
|
| 265 |
-
Args:
|
| 266 |
-
in_features: int
|
| 267 |
-
Number of input features.
|
| 268 |
-
out_features: int
|
| 269 |
-
Number of output features.
|
| 270 |
-
n_d : int
|
| 271 |
-
Dimension of the prediction layer (usually between 4 and 64)
|
| 272 |
-
n_a : int
|
| 273 |
-
Dimension of the attention layer (usually between 4 and 64)
|
| 274 |
-
n_steps : int
|
| 275 |
-
Number of successive steps in the network (usually between 3 and 10)
|
| 276 |
-
gamma : float
|
| 277 |
-
Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)
|
| 278 |
-
cat_idxs : list of int
|
| 279 |
-
Index of each categorical column in the dataset
|
| 280 |
-
cat_dims : list of int
|
| 281 |
-
Number of categories in each categorical column
|
| 282 |
-
cat_emb_dim : int or list of int
|
| 283 |
-
Size of the embedding of categorical features
|
| 284 |
-
if int, all categorical features will have same embedding size
|
| 285 |
-
if list of int, every corresponding feature will have specific size
|
| 286 |
-
n_independent : int
|
| 287 |
-
Number of independent GLU layer in each GLU block (default 2)
|
| 288 |
-
n_shared : int
|
| 289 |
-
Number of independent GLU layer in each GLU block (default 2)
|
| 290 |
-
epsilon : float
|
| 291 |
-
Avoid log(0), this should be kept very low
|
| 292 |
-
virtual_batch_size : int
|
| 293 |
-
Batch size for Ghost Batch Normalization
|
| 294 |
-
momentum : float
|
| 295 |
-
Float value between 0 and 1 which will be used for momentum in all batch norm
|
| 296 |
-
mask_type : str
|
| 297 |
-
Either "sparsemax" or "entmax" : this is the masking function to use
|
| 298 |
-
"""
|
| 299 |
-
from pytorch_tabnet.tab_network import TabNet as _TabNetModel
|
| 300 |
-
|
| 301 |
-
super().__init__(in_features, out_features)
|
| 302 |
-
|
| 303 |
-
self._tabnet = _TabNetModel(
|
| 304 |
-
input_dim=in_features,
|
| 305 |
-
output_dim=out_features,
|
| 306 |
-
n_d=n_d,
|
| 307 |
-
n_a=n_a,
|
| 308 |
-
n_steps=n_steps,
|
| 309 |
-
gamma=gamma,
|
| 310 |
-
cat_idxs=cat_idxs,
|
| 311 |
-
cat_dims=cat_dims,
|
| 312 |
-
cat_emb_dim=cat_emb_dim,
|
| 313 |
-
n_independent=n_independent,
|
| 314 |
-
n_shared=n_shared,
|
| 315 |
-
epsilon=epsilon,
|
| 316 |
-
virtual_batch_size=virtual_batch_size,
|
| 317 |
-
momentum=momentum,
|
| 318 |
-
mask_type=mask_type,
|
| 319 |
-
group_attention_matrix=rand(4, in_features),
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
self.activation = nn.LeakyReLU(negative_slope=0.01)
|
| 323 |
-
|
| 324 |
-
def forward(self, x):
|
| 325 |
-
"""Run model forward"""
|
| 326 |
-
# TODO: USE THIS LOSS COMPONENT
|
| 327 |
-
# loss = self.compute_loss(output, y)
|
| 328 |
-
# Add the overall sparsity loss
|
| 329 |
-
# loss = loss - self.lambda_sparse * M_loss
|
| 330 |
-
x = self.cat_modes(x)
|
| 331 |
-
out1, M_loss = self._tabnet(x)
|
| 332 |
-
return self.activation(out1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/multimodal.py
DELETED
|
@@ -1,417 +0,0 @@
|
|
| 1 |
-
"""The default composite model architecture for PVNet"""
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
from collections import OrderedDict
|
| 5 |
-
from typing import Any, Optional
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
from omegaconf import DictConfig
|
| 9 |
-
from torch import nn
|
| 10 |
-
|
| 11 |
-
import pvnet
|
| 12 |
-
from pvnet.models.base_model import BaseModel
|
| 13 |
-
from pvnet.models.multimodal.basic_blocks import ImageEmbedding
|
| 14 |
-
from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
|
| 15 |
-
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
|
| 16 |
-
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder
|
| 17 |
-
from pvnet.optimizers import AbstractOptimizer
|
| 18 |
-
|
| 19 |
-
logger = logging.getLogger(__name__)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class Model(BaseModel):
|
| 23 |
-
"""Neural network which combines information from different sources
|
| 24 |
-
|
| 25 |
-
Architecture is roughly as follows:
|
| 26 |
-
|
| 27 |
-
- Satellite data, if included, is put through an encoder which transforms it from 4D, with time,
|
| 28 |
-
channel, height, and width dimensions to become a 1D feature vector.
|
| 29 |
-
- NWP, if included, is put through a similar encoder.
|
| 30 |
-
- PV site-level data, if included, is put through an encoder which transforms it from 2D, with
|
| 31 |
-
time and system-ID dimensions, to become a 1D feature vector.
|
| 32 |
-
- The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun
|
| 33 |
-
paramters* are concatenated into a 1D feature vector and passed through another neural
|
| 34 |
-
network to combine them and produce a forecast.
|
| 35 |
-
|
| 36 |
-
* if included
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
name = "conv3d_sat_nwp"
|
| 40 |
-
|
| 41 |
-
def __init__(
|
| 42 |
-
self,
|
| 43 |
-
output_network: AbstractLinearNetwork,
|
| 44 |
-
output_quantiles: Optional[list[float]] = None,
|
| 45 |
-
nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None,
|
| 46 |
-
sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
|
| 47 |
-
pv_encoder: Optional[AbstractSitesEncoder] = None,
|
| 48 |
-
sensor_encoder: Optional[AbstractSitesEncoder] = None,
|
| 49 |
-
add_image_embedding_channel: bool = False,
|
| 50 |
-
include_gsp_yield_history: bool = True,
|
| 51 |
-
include_site_yield_history: Optional[bool] = False,
|
| 52 |
-
include_sun: bool = True,
|
| 53 |
-
include_time: bool = False,
|
| 54 |
-
location_id_mapping: Optional[dict[Any, int]] = None,
|
| 55 |
-
embedding_dim: Optional[int] = 16,
|
| 56 |
-
forecast_minutes: int = 30,
|
| 57 |
-
history_minutes: int = 60,
|
| 58 |
-
sat_history_minutes: Optional[int] = None,
|
| 59 |
-
min_sat_delay_minutes: Optional[int] = 30,
|
| 60 |
-
nwp_forecast_minutes: Optional[DictConfig] = None,
|
| 61 |
-
nwp_history_minutes: Optional[DictConfig] = None,
|
| 62 |
-
pv_history_minutes: Optional[int] = None,
|
| 63 |
-
sensor_history_minutes: Optional[int] = None,
|
| 64 |
-
sensor_forecast_minutes: Optional[int] = None,
|
| 65 |
-
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
|
| 66 |
-
target_key: str = "gsp",
|
| 67 |
-
interval_minutes: int = 30,
|
| 68 |
-
nwp_interval_minutes: Optional[DictConfig] = None,
|
| 69 |
-
pv_interval_minutes: int = 5,
|
| 70 |
-
sat_interval_minutes: int = 5,
|
| 71 |
-
sensor_interval_minutes: int = 30,
|
| 72 |
-
timestep_intervals_to_plot: Optional[list[int]] = None,
|
| 73 |
-
adapt_batches: Optional[bool] = False,
|
| 74 |
-
forecast_minutes_ignore: Optional[int] = 0,
|
| 75 |
-
save_validation_results_csv: Optional[bool] = False,
|
| 76 |
-
):
|
| 77 |
-
"""Neural network which combines information from different sources.
|
| 78 |
-
|
| 79 |
-
Notes:
|
| 80 |
-
In the args, where it says a module `m` is partially instantiated, it means that a
|
| 81 |
-
normal pytorch module will be returned by running `mod = m(**kwargs)`. In this library,
|
| 82 |
-
this partial instantiation is generally achieved using partial instantiation via hydra.
|
| 83 |
-
However, the arg is still valid as long as `m(**kwargs)` returns a valid pytorch module
|
| 84 |
-
- for example if `m` is a regular function.
|
| 85 |
-
|
| 86 |
-
Args:
|
| 87 |
-
output_network: A partially instantiated pytorch Module class used to combine the 1D
|
| 88 |
-
features to produce the forecast.
|
| 89 |
-
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
|
| 90 |
-
None the output is a single value.
|
| 91 |
-
nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to
|
| 92 |
-
encode the NWP data from 4D into a 1D feature vector from different sources.
|
| 93 |
-
sat_encoder: A partially instantiated pytorch Module class used to encode the satellite
|
| 94 |
-
data from 4D into a 1D feature vector.
|
| 95 |
-
pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
|
| 96 |
-
PV data from 2D into a 1D feature vector.
|
| 97 |
-
add_image_embedding_channel: Add a channel to the NWP and satellite data with the
|
| 98 |
-
embedding of the GSP ID.
|
| 99 |
-
include_gsp_yield_history: Include GSP yield data.
|
| 100 |
-
include_site_yield_history: Include Site yield data.
|
| 101 |
-
include_sun: Include sun azimuth and altitude data.
|
| 102 |
-
include_time: Include sine and cosine of dates and times.
|
| 103 |
-
location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
|
| 104 |
-
not used if this is not provided.
|
| 105 |
-
embedding_dim: Number of embedding dimensions to use for GSP ID.
|
| 106 |
-
forecast_minutes: The amount of minutes that should be forecasted.
|
| 107 |
-
history_minutes: The default amount of historical minutes that are used.
|
| 108 |
-
sat_history_minutes: Length of recent observations used for satellite inputs. Defaults
|
| 109 |
-
to `history_minutes` if not provided.
|
| 110 |
-
min_sat_delay_minutes: Minimum delay with respect to t0 of the latest available
|
| 111 |
-
satellite image.
|
| 112 |
-
nwp_forecast_minutes: Period of future NWP forecast data used as input. Defaults to
|
| 113 |
-
`forecast_minutes` if not provided.
|
| 114 |
-
nwp_history_minutes: Period of historical NWP forecast used as input. Defaults to
|
| 115 |
-
`history_minutes` if not provided.
|
| 116 |
-
pv_history_minutes: Length of recent site-level PV data used as
|
| 117 |
-
input. Defaults to `history_minutes` if not provided.
|
| 118 |
-
optimizer: Optimizer factory function used for network.
|
| 119 |
-
target_key: The key of the target variable in the batch.
|
| 120 |
-
interval_minutes: The interval between each sample of the target data
|
| 121 |
-
nwp_interval_minutes: Dictionary of the intervals between each sample of the NWP
|
| 122 |
-
data for each source
|
| 123 |
-
pv_interval_minutes: The interval between each sample of the PV data
|
| 124 |
-
sat_interval_minutes: The interval between each sample of the satellite data
|
| 125 |
-
sensor_interval_minutes: The interval between each sample of the sensor data
|
| 126 |
-
timestep_intervals_to_plot: Intervals, in timesteps, to plot in
|
| 127 |
-
addition to the full forecast
|
| 128 |
-
sensor_encoder: Encoder for sensor data
|
| 129 |
-
sensor_history_minutes: Length of recent sensor data used as input.
|
| 130 |
-
sensor_forecast_minutes: Length of forecast sensor data used as input.
|
| 131 |
-
adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
|
| 132 |
-
the model to use. This allows us to overprepare batches and slice from them for the
|
| 133 |
-
data we need for a model run.
|
| 134 |
-
forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses.
|
| 135 |
-
For example if set to 60, the model doesnt predict the first 60 minutes
|
| 136 |
-
save_validation_results_csv: whether to save full csv outputs from validation results.
|
| 137 |
-
"""
|
| 138 |
-
|
| 139 |
-
self.include_gsp_yield_history = include_gsp_yield_history
|
| 140 |
-
self.include_site_yield_history = include_site_yield_history
|
| 141 |
-
self.include_sat = sat_encoder is not None
|
| 142 |
-
self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
|
| 143 |
-
self.include_pv = pv_encoder is not None
|
| 144 |
-
self.include_sun = include_sun
|
| 145 |
-
self.include_time = include_time
|
| 146 |
-
self.include_sensor = sensor_encoder is not None
|
| 147 |
-
self.location_id_mapping = location_id_mapping
|
| 148 |
-
self.embedding_dim = embedding_dim
|
| 149 |
-
self.add_image_embedding_channel = add_image_embedding_channel
|
| 150 |
-
self.interval_minutes = interval_minutes
|
| 151 |
-
self.min_sat_delay_minutes = min_sat_delay_minutes
|
| 152 |
-
self.adapt_batches = adapt_batches
|
| 153 |
-
|
| 154 |
-
if self.location_id_mapping is None:
|
| 155 |
-
logger.warning("location_id_mapping` is not provided, "
|
| 156 |
-
"defaulting to outdated GSP mapping (0 to 317)")
|
| 157 |
-
|
| 158 |
-
# Note 318 is the 2024 UK GSP count, so this is a temporary fix
|
| 159 |
-
# for models trained with this default embedding
|
| 160 |
-
self.location_id_mapping = {i: i for i in range(318)}
|
| 161 |
-
|
| 162 |
-
# in the future location_id_mapping could be None,
|
| 163 |
-
# and in this case use_id_embedding should be False
|
| 164 |
-
self.use_id_embedding = self.embedding_dim is not None
|
| 165 |
-
|
| 166 |
-
if self.use_id_embedding:
|
| 167 |
-
num_embeddings = max(self.location_id_mapping.values()) + 1
|
| 168 |
-
|
| 169 |
-
super().__init__(
|
| 170 |
-
history_minutes=history_minutes,
|
| 171 |
-
forecast_minutes=forecast_minutes,
|
| 172 |
-
optimizer=optimizer,
|
| 173 |
-
output_quantiles=output_quantiles,
|
| 174 |
-
target_key=target_key,
|
| 175 |
-
interval_minutes=interval_minutes,
|
| 176 |
-
timestep_intervals_to_plot=timestep_intervals_to_plot,
|
| 177 |
-
forecast_minutes_ignore=forecast_minutes_ignore,
|
| 178 |
-
save_validation_results_csv=save_validation_results_csv
|
| 179 |
-
)
|
| 180 |
-
|
| 181 |
-
# Number of features expected by the output_network
|
| 182 |
-
# Add to this as network pieces are constructed
|
| 183 |
-
fusion_input_features = 0
|
| 184 |
-
|
| 185 |
-
if self.include_sat:
|
| 186 |
-
# Param checks
|
| 187 |
-
assert sat_history_minutes is not None
|
| 188 |
-
|
| 189 |
-
self.sat_sequence_len = (
|
| 190 |
-
sat_history_minutes - min_sat_delay_minutes
|
| 191 |
-
) // sat_interval_minutes + 1
|
| 192 |
-
|
| 193 |
-
self.sat_encoder = sat_encoder(
|
| 194 |
-
sequence_length=self.sat_sequence_len,
|
| 195 |
-
in_channels=sat_encoder.keywords["in_channels"] + add_image_embedding_channel,
|
| 196 |
-
)
|
| 197 |
-
if add_image_embedding_channel:
|
| 198 |
-
self.sat_embed = ImageEmbedding(
|
| 199 |
-
num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
# Update num features
|
| 203 |
-
fusion_input_features += self.sat_encoder.out_features
|
| 204 |
-
|
| 205 |
-
if self.include_nwp:
|
| 206 |
-
# Param checks
|
| 207 |
-
assert nwp_forecast_minutes is not None
|
| 208 |
-
assert nwp_history_minutes is not None
|
| 209 |
-
|
| 210 |
-
# For each NWP encoder the forecast and history minutes must be set
|
| 211 |
-
assert set(nwp_encoders_dict.keys()) == set(nwp_forecast_minutes.keys())
|
| 212 |
-
assert set(nwp_encoders_dict.keys()) == set(nwp_history_minutes.keys())
|
| 213 |
-
|
| 214 |
-
if nwp_interval_minutes is None:
|
| 215 |
-
nwp_interval_minutes = dict.fromkeys(nwp_encoders_dict.keys(), 60)
|
| 216 |
-
|
| 217 |
-
self.nwp_encoders_dict = torch.nn.ModuleDict()
|
| 218 |
-
if add_image_embedding_channel:
|
| 219 |
-
self.nwp_embed_dict = torch.nn.ModuleDict()
|
| 220 |
-
|
| 221 |
-
for nwp_source in nwp_encoders_dict.keys():
|
| 222 |
-
nwp_sequence_len = (
|
| 223 |
-
nwp_history_minutes[nwp_source] // nwp_interval_minutes[nwp_source]
|
| 224 |
-
+ nwp_forecast_minutes[nwp_source] // nwp_interval_minutes[nwp_source]
|
| 225 |
-
+ 1
|
| 226 |
-
)
|
| 227 |
-
|
| 228 |
-
self.nwp_encoders_dict[nwp_source] = nwp_encoders_dict[nwp_source](
|
| 229 |
-
sequence_length=nwp_sequence_len,
|
| 230 |
-
in_channels=(
|
| 231 |
-
nwp_encoders_dict[nwp_source].keywords["in_channels"]
|
| 232 |
-
+ add_image_embedding_channel
|
| 233 |
-
),
|
| 234 |
-
)
|
| 235 |
-
if add_image_embedding_channel:
|
| 236 |
-
self.nwp_embed_dict[nwp_source] = ImageEmbedding(
|
| 237 |
-
num_embeddings,
|
| 238 |
-
nwp_sequence_len,
|
| 239 |
-
self.nwp_encoders_dict[nwp_source].image_size_pixels,
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
# Update num features
|
| 243 |
-
fusion_input_features += self.nwp_encoders_dict[nwp_source].out_features
|
| 244 |
-
|
| 245 |
-
if self.include_pv:
|
| 246 |
-
assert pv_history_minutes is not None
|
| 247 |
-
|
| 248 |
-
self.pv_encoder = pv_encoder(
|
| 249 |
-
sequence_length=pv_history_minutes // pv_interval_minutes + 1,
|
| 250 |
-
target_key_to_use=self._target_key,
|
| 251 |
-
input_key_to_use="site",
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
# Update num features
|
| 255 |
-
fusion_input_features += self.pv_encoder.out_features
|
| 256 |
-
|
| 257 |
-
if self.include_sensor:
|
| 258 |
-
if sensor_history_minutes is None:
|
| 259 |
-
sensor_history_minutes = history_minutes
|
| 260 |
-
if sensor_forecast_minutes is None:
|
| 261 |
-
sensor_forecast_minutes = forecast_minutes
|
| 262 |
-
|
| 263 |
-
self.sensor_encoder = sensor_encoder(
|
| 264 |
-
sequence_length=sensor_history_minutes // sensor_interval_minutes
|
| 265 |
-
+ sensor_forecast_minutes // sensor_interval_minutes
|
| 266 |
-
+ 1,
|
| 267 |
-
target_key_to_use=self._target_key,
|
| 268 |
-
input_key_to_use="sensor",
|
| 269 |
-
)
|
| 270 |
-
|
| 271 |
-
# Update num features
|
| 272 |
-
fusion_input_features += self.sensor_encoder.out_features
|
| 273 |
-
|
| 274 |
-
if self.use_id_embedding:
|
| 275 |
-
self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
| 276 |
-
|
| 277 |
-
# Update num features
|
| 278 |
-
fusion_input_features += embedding_dim
|
| 279 |
-
|
| 280 |
-
if self.include_sun:
|
| 281 |
-
self.sun_fc1 = nn.Linear(
|
| 282 |
-
in_features=2
|
| 283 |
-
* (self.forecast_len + self.forecast_len_ignore + self.history_len + 1),
|
| 284 |
-
out_features=16,
|
| 285 |
-
)
|
| 286 |
-
|
| 287 |
-
# Update num features
|
| 288 |
-
fusion_input_features += 16
|
| 289 |
-
|
| 290 |
-
if self.include_time:
|
| 291 |
-
self.time_fc1 = nn.Linear(
|
| 292 |
-
in_features=4
|
| 293 |
-
* (self.forecast_len + self.forecast_len_ignore + self.history_len + 1),
|
| 294 |
-
out_features=32,
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
# Update num features
|
| 298 |
-
fusion_input_features += 32
|
| 299 |
-
|
| 300 |
-
if include_gsp_yield_history:
|
| 301 |
-
# Update num features
|
| 302 |
-
fusion_input_features += self.history_len
|
| 303 |
-
|
| 304 |
-
if include_site_yield_history:
|
| 305 |
-
# Update num features
|
| 306 |
-
fusion_input_features += self.history_len + 1
|
| 307 |
-
|
| 308 |
-
self.output_network = output_network(
|
| 309 |
-
in_features=fusion_input_features,
|
| 310 |
-
out_features=self.num_output_features,
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
self.save_hyperparameters()
|
| 314 |
-
|
| 315 |
-
def forward(self, x):
|
| 316 |
-
"""Run model forward"""
|
| 317 |
-
|
| 318 |
-
if self.adapt_batches:
|
| 319 |
-
x = self._adapt_batch(x)
|
| 320 |
-
|
| 321 |
-
if self.use_id_embedding:
|
| 322 |
-
# eg: x['gsp_id] = [1] with location_id_mapping = {1:0}, would give [0]
|
| 323 |
-
id = torch.tensor(
|
| 324 |
-
[self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
|
| 325 |
-
device=self.device,
|
| 326 |
-
dtype=torch.int64,
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
modes = OrderedDict()
|
| 330 |
-
# ******************* Satellite imagery *************************
|
| 331 |
-
if self.include_sat:
|
| 332 |
-
# Shape: batch_size, seq_length, channel, height, width
|
| 333 |
-
sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
|
| 334 |
-
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
|
| 335 |
-
|
| 336 |
-
if self.add_image_embedding_channel:
|
| 337 |
-
sat_data = self.sat_embed(sat_data, id)
|
| 338 |
-
modes["sat"] = self.sat_encoder(sat_data)
|
| 339 |
-
|
| 340 |
-
# *********************** NWP Data ************************************
|
| 341 |
-
if self.include_nwp:
|
| 342 |
-
# Loop through potentially many NMPs
|
| 343 |
-
for nwp_source in self.nwp_encoders_dict:
|
| 344 |
-
# shape: batch_size, seq_len, n_chans, height, width
|
| 345 |
-
nwp_data = x["nwp"][nwp_source]["nwp"].float()
|
| 346 |
-
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
|
| 347 |
-
# Some NWP variables can overflow into NaNs when normalised if they have extreme
|
| 348 |
-
# tails
|
| 349 |
-
nwp_data = torch.clip(nwp_data, min=-50, max=50)
|
| 350 |
-
|
| 351 |
-
if self.add_image_embedding_channel:
|
| 352 |
-
nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)
|
| 353 |
-
|
| 354 |
-
nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
|
| 355 |
-
modes[f"nwp/{nwp_source}"] = nwp_out
|
| 356 |
-
|
| 357 |
-
# *********************** Site Data *************************************
|
| 358 |
-
# Add site-level yield history
|
| 359 |
-
if self.include_site_yield_history:
|
| 360 |
-
site_history = x["site"][:, : self.history_len + 1].float()
|
| 361 |
-
site_history = site_history.reshape(site_history.shape[0], -1)
|
| 362 |
-
modes["site"] = site_history
|
| 363 |
-
|
| 364 |
-
# Add site-level yield history through PV encoder
|
| 365 |
-
if self.include_pv:
|
| 366 |
-
if self._target_key != "site":
|
| 367 |
-
modes["site"] = self.pv_encoder(x)
|
| 368 |
-
else:
|
| 369 |
-
# Target is PV, so only take the history
|
| 370 |
-
# Copy batch
|
| 371 |
-
x_tmp = x.copy()
|
| 372 |
-
x_tmp["site"] = x_tmp["site"][:, : self.history_len + 1]
|
| 373 |
-
modes["site"] = self.pv_encoder(x_tmp)
|
| 374 |
-
|
| 375 |
-
# *********************** GSP Data ************************************
|
| 376 |
-
# add gsp yield history
|
| 377 |
-
if self.include_gsp_yield_history:
|
| 378 |
-
gsp_history = x["gsp"][:, : self.history_len].float()
|
| 379 |
-
gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
|
| 380 |
-
modes["gsp"] = gsp_history
|
| 381 |
-
|
| 382 |
-
# ********************** Embedding of GSP/Site ID ********************
|
| 383 |
-
if self.use_id_embedding:
|
| 384 |
-
modes["id"] = self.embed(id)
|
| 385 |
-
|
| 386 |
-
if self.include_sun:
|
| 387 |
-
# Use only new direct keys
|
| 388 |
-
sun = torch.cat(
|
| 389 |
-
(
|
| 390 |
-
x["solar_azimuth"],
|
| 391 |
-
x["solar_elevation"],
|
| 392 |
-
),
|
| 393 |
-
dim=1,
|
| 394 |
-
).float()
|
| 395 |
-
sun = self.sun_fc1(sun)
|
| 396 |
-
modes["sun"] = sun
|
| 397 |
-
|
| 398 |
-
if self.include_time:
|
| 399 |
-
time = torch.cat(
|
| 400 |
-
(
|
| 401 |
-
x[f"{self._target_key}_date_sin"],
|
| 402 |
-
x[f"{self._target_key}_date_cos"],
|
| 403 |
-
x[f"{self._target_key}_time_sin"],
|
| 404 |
-
x[f"{self._target_key}_time_cos"],
|
| 405 |
-
),
|
| 406 |
-
dim=1,
|
| 407 |
-
).float()
|
| 408 |
-
time = self.time_fc1(time)
|
| 409 |
-
modes["time"] = time
|
| 410 |
-
|
| 411 |
-
out = self.output_network(modes)
|
| 412 |
-
|
| 413 |
-
if self.use_quantile_regression:
|
| 414 |
-
# Shape: batch_size, seq_length * num_quantiles
|
| 415 |
-
out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))
|
| 416 |
-
|
| 417 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/readme.md
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
## Multimodal model architecture
|
| 2 |
-
|
| 3 |
-
These models fusion models to predict GSP power output based on NWP, non-HRV satellite, GSP output history, solor coordinates, and GSP ID.
|
| 4 |
-
|
| 5 |
-
The core model is `multimodel.Model`, and its architecture is shown in the diagram below.
|
| 6 |
-
|
| 7 |
-

|
| 8 |
-
|
| 9 |
-
This model uses encoders which take 4D (time, channel, x, y) inputs of NWP and satellite and encode them into 1D feature vectors. Different encoders are contained inside `encoders`.
|
| 10 |
-
|
| 11 |
-
Different choices for the fusion model are contained inside `linear_networks`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/site_encoders/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Submodels to encode site-level PV data"""
|
|
|
|
|
|
pvnet/models/multimodal/site_encoders/basic_blocks.py
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
"""Basic blocks for PV-site encoders"""
|
| 2 |
-
from abc import ABCMeta, abstractmethod
|
| 3 |
-
|
| 4 |
-
from torch import nn
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta):
|
| 8 |
-
"""Abstract class for encoder for output data from multiple PV sites.
|
| 9 |
-
|
| 10 |
-
The encoder will take an input of shape (batch_size, sequence_length, num_sites)
|
| 11 |
-
and return an output of shape (batch_size, out_features).
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
def __init__(
|
| 15 |
-
self,
|
| 16 |
-
sequence_length: int,
|
| 17 |
-
num_sites: int,
|
| 18 |
-
out_features: int,
|
| 19 |
-
):
|
| 20 |
-
"""Abstract class for PV site-level encoder.
|
| 21 |
-
|
| 22 |
-
Args:
|
| 23 |
-
sequence_length: The time sequence length of the data.
|
| 24 |
-
num_sites: Number of PV sites in the input data.
|
| 25 |
-
out_features: Number of output features.
|
| 26 |
-
"""
|
| 27 |
-
super().__init__()
|
| 28 |
-
self.sequence_length = sequence_length
|
| 29 |
-
self.num_sites = num_sites
|
| 30 |
-
self.out_features = out_features
|
| 31 |
-
|
| 32 |
-
@abstractmethod
|
| 33 |
-
def forward(self):
|
| 34 |
-
"""Run model forward"""
|
| 35 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/site_encoders/encoders.py
DELETED
|
@@ -1,284 +0,0 @@
|
|
| 1 |
-
"""Encoder modules for the site-level PV data.
|
| 2 |
-
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import einops
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn
|
| 8 |
-
|
| 9 |
-
from pvnet.models.multimodal.linear_networks.networks import ResFCNet2
|
| 10 |
-
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class SimpleLearnedAggregator(AbstractSitesEncoder):
|
| 14 |
-
"""A simple model which learns a different weighted-average across all PV sites for each GSP.
|
| 15 |
-
|
| 16 |
-
Each sequence from each site is independently encodeded through some dense layers wih skip-
|
| 17 |
-
connections, then the encoded form of each sequence is aggregated through a learned weighted-sum
|
| 18 |
-
and finally put through more dense layers.
|
| 19 |
-
|
| 20 |
-
This model was written to be a simplified version of a single-headed attention layer.
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
def __init__(
|
| 24 |
-
self,
|
| 25 |
-
sequence_length: int,
|
| 26 |
-
num_sites: int,
|
| 27 |
-
out_features: int,
|
| 28 |
-
value_dim: int = 10,
|
| 29 |
-
value_enc_resblocks: int = 2,
|
| 30 |
-
final_resblocks: int = 2,
|
| 31 |
-
):
|
| 32 |
-
"""A simple sequence encoder and weighted-average model.
|
| 33 |
-
|
| 34 |
-
Args:
|
| 35 |
-
sequence_length: The time sequence length of the data.
|
| 36 |
-
num_sites: Number of PV sites in the input data.
|
| 37 |
-
out_features: Number of output features.
|
| 38 |
-
value_dim: The number of features in each encoded sequence. Similar to the value
|
| 39 |
-
dimension in single- or multi-head attention.
|
| 40 |
-
value_dim: The number of features in each encoded sequence. Similar to the value
|
| 41 |
-
dimension in single- or multi-head attention.
|
| 42 |
-
value_enc_resblocks: Number of residual blocks in the value-encoder sub-network.
|
| 43 |
-
final_resblocks: Number of residual blocks in the final sub-network.
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
super().__init__(sequence_length, num_sites, out_features)
|
| 47 |
-
|
| 48 |
-
# Network used to encode each PV site sequence
|
| 49 |
-
self._value_encoder = nn.Sequential(
|
| 50 |
-
ResFCNet2(
|
| 51 |
-
in_features=sequence_length,
|
| 52 |
-
out_features=value_dim,
|
| 53 |
-
fc_hidden_features=value_dim,
|
| 54 |
-
n_res_blocks=value_enc_resblocks,
|
| 55 |
-
res_block_layers=2,
|
| 56 |
-
dropout_frac=0,
|
| 57 |
-
),
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
# The learned weighted average is stored in an embedding layer for ease of use
|
| 61 |
-
self._attention_network = nn.Sequential(
|
| 62 |
-
nn.Embedding(318, num_sites),
|
| 63 |
-
nn.Softmax(dim=1),
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# Network used to process weighted average
|
| 67 |
-
self.output_network = ResFCNet2(
|
| 68 |
-
in_features=value_dim,
|
| 69 |
-
out_features=out_features,
|
| 70 |
-
fc_hidden_features=value_dim,
|
| 71 |
-
n_res_blocks=final_resblocks,
|
| 72 |
-
res_block_layers=2,
|
| 73 |
-
dropout_frac=0,
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
def _calculate_attention(self, x):
|
| 77 |
-
gsp_ids = x["gsp_id"].squeeze().int()
|
| 78 |
-
attention = self._attention_network(gsp_ids)
|
| 79 |
-
return attention
|
| 80 |
-
|
| 81 |
-
def _encode_value(self, x):
|
| 82 |
-
# Shape: [batch size, sequence length, PV site]
|
| 83 |
-
pv_site_seqs = x["pv"].float()
|
| 84 |
-
batch_size = pv_site_seqs.shape[0]
|
| 85 |
-
|
| 86 |
-
pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1)
|
| 87 |
-
|
| 88 |
-
x_seq_enc = self._value_encoder(pv_site_seqs)
|
| 89 |
-
x_seq_out = x_seq_enc.unflatten(0, (batch_size, self.num_sites))
|
| 90 |
-
return x_seq_out
|
| 91 |
-
|
| 92 |
-
def forward(self, x):
|
| 93 |
-
"""Run model forward"""
|
| 94 |
-
# Output has shape: [batch size, num_sites, value_dim]
|
| 95 |
-
encodeded_seqs = self._encode_value(x)
|
| 96 |
-
|
| 97 |
-
# Calculate learned averaging weights
|
| 98 |
-
attn_avg_weights = self._calculate_attention(x)
|
| 99 |
-
|
| 100 |
-
# Take weighted average across num_sites
|
| 101 |
-
value_weighted_avg = (encodeded_seqs * attn_avg_weights.unsqueeze(-1)).sum(dim=1)
|
| 102 |
-
|
| 103 |
-
# Put through final processing layers
|
| 104 |
-
x_out = self.output_network(value_weighted_avg)
|
| 105 |
-
|
| 106 |
-
return x_out
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class SingleAttentionNetwork(AbstractSitesEncoder):
|
| 110 |
-
"""A simple attention-based model with a single multihead attention layer
|
| 111 |
-
|
| 112 |
-
For the attention layer the query is based on the target alone, the key is based on the
|
| 113 |
-
input ID and the recent input data, the value is based on the recent input data.
|
| 114 |
-
|
| 115 |
-
"""
|
| 116 |
-
|
| 117 |
-
def __init__(
|
| 118 |
-
self,
|
| 119 |
-
sequence_length: int,
|
| 120 |
-
num_sites: int,
|
| 121 |
-
out_features: int,
|
| 122 |
-
kdim: int = 10,
|
| 123 |
-
id_embed_dim: int = 10,
|
| 124 |
-
num_heads: int = 2,
|
| 125 |
-
n_kv_res_blocks: int = 2,
|
| 126 |
-
kv_res_block_layers: int = 2,
|
| 127 |
-
use_id_in_value: bool = False,
|
| 128 |
-
target_id_dim: int = 318,
|
| 129 |
-
target_key_to_use: str = "gsp",
|
| 130 |
-
input_key_to_use: str = "site",
|
| 131 |
-
num_channels: int = 1,
|
| 132 |
-
num_sites_in_inference: int = 1,
|
| 133 |
-
):
|
| 134 |
-
"""A simple attention-based model with a single multihead attention layer
|
| 135 |
-
|
| 136 |
-
Args:
|
| 137 |
-
sequence_length: The time sequence length of the data.
|
| 138 |
-
num_sites: Number of sites in the input data.
|
| 139 |
-
out_features: Number of output features. In this network this is also the embed and
|
| 140 |
-
value dimension in the multi-head attention layer.
|
| 141 |
-
kdim: The dimensions used the keys.
|
| 142 |
-
id_embed_dim: Number of dimensiosn used in the site ID embedding layer(s).
|
| 143 |
-
num_heads: Number of parallel attention heads. Note that `out_features` will be split
|
| 144 |
-
across `num_heads` so `out_features` must be a multiple of `num_heads`.
|
| 145 |
-
n_kv_res_blocks: Number of residual blocks to use in the key and value encoders.
|
| 146 |
-
kv_res_block_layers: Number of fully-connected layers used in each residual block within
|
| 147 |
-
the key and value encoders.
|
| 148 |
-
use_id_in_value: Whether to use a site ID embedding in network used to produce the
|
| 149 |
-
value for the attention layer.
|
| 150 |
-
target_id_dim: The number of unique IDs.
|
| 151 |
-
target_key_to_use: The key to use for the target in the attention layer.
|
| 152 |
-
input_key_to_use: The key to use for the input in the attention layer.
|
| 153 |
-
num_channels: Number of channels in the input data. For single site generation,
|
| 154 |
-
this will be 1, as there is not channel dimension, for Sensors,
|
| 155 |
-
this will probably be higher than that
|
| 156 |
-
num_sites_in_inference: Number of sites to use in inference.
|
| 157 |
-
This is used to determine the number of sites to use in the
|
| 158 |
-
attention layer, for a single site, 1 works, while for multiple sites
|
| 159 |
-
(such as multiple sensors), this would be higher than that
|
| 160 |
-
|
| 161 |
-
"""
|
| 162 |
-
super().__init__(sequence_length, num_sites, out_features)
|
| 163 |
-
self.sequence_length = sequence_length
|
| 164 |
-
self.target_id_embedding = nn.Embedding(target_id_dim, out_features)
|
| 165 |
-
self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
|
| 166 |
-
self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
|
| 167 |
-
self.use_id_in_value = use_id_in_value
|
| 168 |
-
self.target_key_to_use = target_key_to_use
|
| 169 |
-
self.input_key_to_use = input_key_to_use
|
| 170 |
-
self.num_channels = num_channels
|
| 171 |
-
self.num_sites_in_inference = num_sites_in_inference
|
| 172 |
-
|
| 173 |
-
if use_id_in_value:
|
| 174 |
-
self.value_id_embedding = nn.Embedding(num_sites, id_embed_dim)
|
| 175 |
-
|
| 176 |
-
self._value_encoder = nn.Sequential(
|
| 177 |
-
ResFCNet2(
|
| 178 |
-
in_features=sequence_length * self.num_channels
|
| 179 |
-
+ int(use_id_in_value) * id_embed_dim,
|
| 180 |
-
out_features=out_features,
|
| 181 |
-
fc_hidden_features=sequence_length * self.num_channels,
|
| 182 |
-
n_res_blocks=n_kv_res_blocks,
|
| 183 |
-
res_block_layers=kv_res_block_layers,
|
| 184 |
-
dropout_frac=0,
|
| 185 |
-
),
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
self._key_encoder = nn.Sequential(
|
| 189 |
-
ResFCNet2(
|
| 190 |
-
in_features=id_embed_dim + sequence_length * self.num_channels,
|
| 191 |
-
out_features=kdim,
|
| 192 |
-
fc_hidden_features=id_embed_dim + sequence_length * self.num_channels,
|
| 193 |
-
n_res_blocks=n_kv_res_blocks,
|
| 194 |
-
res_block_layers=kv_res_block_layers,
|
| 195 |
-
dropout_frac=0,
|
| 196 |
-
),
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
self.multihead_attn = nn.MultiheadAttention(
|
| 200 |
-
embed_dim=out_features,
|
| 201 |
-
kdim=kdim,
|
| 202 |
-
vdim=out_features,
|
| 203 |
-
num_heads=num_heads,
|
| 204 |
-
batch_first=True,
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
def _encode_inputs(self, x):
|
| 208 |
-
# Shape: [batch size, sequence length, number of sites]
|
| 209 |
-
# Shape: [batch size, station_id, sequence length, channels]
|
| 210 |
-
input_data = x[f"{self.input_key_to_use}"]
|
| 211 |
-
if len(input_data.shape) == 2: # one site per sample
|
| 212 |
-
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
|
| 213 |
-
if len(input_data.shape) == 4: # Has multiple channels
|
| 214 |
-
input_data = input_data[:, :, : self.sequence_length]
|
| 215 |
-
input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
|
| 216 |
-
else:
|
| 217 |
-
input_data = input_data[:, : self.sequence_length]
|
| 218 |
-
site_seqs = input_data.float()
|
| 219 |
-
batch_size = site_seqs.shape[0]
|
| 220 |
-
site_seqs = site_seqs.swapaxes(1, 2) # [batch size, Site ID, sequence length]
|
| 221 |
-
return site_seqs, batch_size
|
| 222 |
-
|
| 223 |
-
def _encode_query(self, x):
|
| 224 |
-
# Select the first one
|
| 225 |
-
if self.target_key_to_use == "gsp":
|
| 226 |
-
# GSP seems to have a different structure
|
| 227 |
-
ids = x[f"{self.target_key_to_use}_id"]
|
| 228 |
-
else:
|
| 229 |
-
ids = x[f"{self.input_key_to_use}_id"]
|
| 230 |
-
ids = ids.int()
|
| 231 |
-
query = self.target_id_embedding(ids).unsqueeze(1)
|
| 232 |
-
return query
|
| 233 |
-
|
| 234 |
-
def _encode_key(self, x):
|
| 235 |
-
site_seqs, batch_size = self._encode_inputs(x)
|
| 236 |
-
|
| 237 |
-
# site ID embeddings are the same for each sample
|
| 238 |
-
site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
|
| 239 |
-
# Each concated (site sequence, site ID embedding) is processed with encoder
|
| 240 |
-
x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
|
| 241 |
-
key = self._key_encoder(x_seq_in)
|
| 242 |
-
|
| 243 |
-
# Reshape to [batch size, site, kdim]
|
| 244 |
-
key = key.unflatten(0, (batch_size, self.num_sites))
|
| 245 |
-
return key
|
| 246 |
-
|
| 247 |
-
def _encode_value(self, x):
|
| 248 |
-
site_seqs, batch_size = self._encode_inputs(x)
|
| 249 |
-
|
| 250 |
-
if self.use_id_in_value:
|
| 251 |
-
# site ID embeddings are the same for each sample
|
| 252 |
-
site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
|
| 253 |
-
# Each concated (site sequence, site ID embedding) is processed with encoder
|
| 254 |
-
x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
|
| 255 |
-
else:
|
| 256 |
-
# Encode each site sequence independently
|
| 257 |
-
x_seq_in = site_seqs.flatten(0, 1)
|
| 258 |
-
value = self._value_encoder(x_seq_in)
|
| 259 |
-
|
| 260 |
-
# Reshape to [batch size, site, vdim]
|
| 261 |
-
value = value.unflatten(0, (batch_size, self.num_sites))
|
| 262 |
-
return value
|
| 263 |
-
|
| 264 |
-
def _attention_forward(self, x, average_attn_weights=True):
|
| 265 |
-
query = self._encode_query(x)
|
| 266 |
-
key = self._encode_key(x)
|
| 267 |
-
value = self._encode_value(x)
|
| 268 |
-
attn_output, attn_weights = self.multihead_attn(
|
| 269 |
-
query, key, value, average_attn_weights=average_attn_weights
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
return attn_output, attn_weights
|
| 273 |
-
|
| 274 |
-
def forward(self, x):
|
| 275 |
-
"""Run model forward"""
|
| 276 |
-
# Do slicing here to only get history
|
| 277 |
-
attn_output, attn_output_weights = self._attention_forward(x)
|
| 278 |
-
|
| 279 |
-
# Reshape from [batch_size, 1, vdim] to [batch_size, vdim]
|
| 280 |
-
x_out = attn_output.squeeze()
|
| 281 |
-
if len(x_out.shape) == 1:
|
| 282 |
-
x_out = x_out.unsqueeze(0)
|
| 283 |
-
|
| 284 |
-
return x_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/multimodal/unimodal_teacher.py
DELETED
|
@@ -1,447 +0,0 @@
|
|
| 1 |
-
"""The default composite model architecture for PVNet"""
|
| 2 |
-
|
| 3 |
-
import glob
|
| 4 |
-
from collections import OrderedDict
|
| 5 |
-
from typing import Any, Optional
|
| 6 |
-
|
| 7 |
-
import hydra
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
from pyaml_env import parse_config
|
| 11 |
-
from torch import nn
|
| 12 |
-
|
| 13 |
-
import pvnet
|
| 14 |
-
from pvnet.models.base_model import BaseModel
|
| 15 |
-
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
|
| 16 |
-
from pvnet.optimizers import AbstractOptimizer
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class Model(BaseModel):
|
| 20 |
-
"""Neural network which combines information from different sources
|
| 21 |
-
|
| 22 |
-
The network is trained via unimodal teachers [1].
|
| 23 |
-
|
| 24 |
-
Architecture is roughly as follows:
|
| 25 |
-
|
| 26 |
-
- Satellite data, if included, is put through an encoder which transforms it from 4D, with time,
|
| 27 |
-
channel, height, and width dimensions to become a 1D feature vector.
|
| 28 |
-
- NWP, if included, is put through a similar encoder.
|
| 29 |
-
- PV site-level data, if included, is put through an encoder which transforms it from 2D, with
|
| 30 |
-
time and system-ID dimensions, to become a 1D feature vector.
|
| 31 |
-
- The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun
|
| 32 |
-
paramters* are concatenated into a 1D feature vector and passed through another neural
|
| 33 |
-
network to combine them and produce a forecast.
|
| 34 |
-
|
| 35 |
-
* if included
|
| 36 |
-
[1] https://arxiv.org/pdf/2305.01233.pdf
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
name = "unimodal_teacher"
|
| 40 |
-
|
| 41 |
-
def __init__(
|
| 42 |
-
self,
|
| 43 |
-
output_network: AbstractLinearNetwork,
|
| 44 |
-
output_quantiles: Optional[list[float]] = None,
|
| 45 |
-
include_gsp_yield_history: bool = True,
|
| 46 |
-
include_sun: bool = True,
|
| 47 |
-
location_id_mapping: Optional[dict[Any, int]] = None,
|
| 48 |
-
embedding_dim: Optional[int] = 16,
|
| 49 |
-
forecast_minutes: int = 30,
|
| 50 |
-
history_minutes: int = 60,
|
| 51 |
-
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
|
| 52 |
-
mode_teacher_dict: dict = {},
|
| 53 |
-
val_best: bool = True,
|
| 54 |
-
cold_start: bool = True,
|
| 55 |
-
enc_loss_frac: float = 0.3,
|
| 56 |
-
adapt_batches: Optional[bool] = False,
|
| 57 |
-
):
|
| 58 |
-
"""Neural network which combines information from different sources.
|
| 59 |
-
|
| 60 |
-
The network is trained via unimodal teachers [1].
|
| 61 |
-
|
| 62 |
-
[1] https://arxiv.org/pdf/2305.01233.pdf
|
| 63 |
-
|
| 64 |
-
Notes:
|
| 65 |
-
In the args, where it says a module `m` is partially instantiated, it means that a
|
| 66 |
-
normal pytorch module will be returned by running `mod = m(**kwargs)`. In this library,
|
| 67 |
-
this partial instantiation is generally achieved using partial instantiation via hydra.
|
| 68 |
-
However, the arg is still valid as long as `m(**kwargs)` returns a valid pytorch module
|
| 69 |
-
- for example if `m` is a regular function.
|
| 70 |
-
|
| 71 |
-
Args:
|
| 72 |
-
output_network: A partially instatiated pytorch Module class used to combine the 1D
|
| 73 |
-
features to produce the forecast.
|
| 74 |
-
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
|
| 75 |
-
None the output is a single value.
|
| 76 |
-
include_gsp_yield_history: Include GSP yield data.
|
| 77 |
-
include_sun: Include sun azimuth and altitude data.
|
| 78 |
-
location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
|
| 79 |
-
not used if this is not provided.
|
| 80 |
-
embedding_dim: Number of embedding dimensions to use for GSP ID
|
| 81 |
-
forecast_minutes: The amount of minutes that should be forecasted.
|
| 82 |
-
history_minutes: The default amount of historical minutes that are used.
|
| 83 |
-
optimizer: Optimizer factory function used for network.
|
| 84 |
-
mode_teacher_dict: A dictionary of paths to different model checkpoint directories,
|
| 85 |
-
which will be used as the unimodal teachers.
|
| 86 |
-
val_best: Whether to load the model which performed best on the validation set. Else the
|
| 87 |
-
last checkpoint is loaded.
|
| 88 |
-
cold_start: Whether to train the uni-modal encoders from scratch. Else start them with
|
| 89 |
-
weights from the uni-modal teachers.
|
| 90 |
-
enc_loss_frac: Fraction of total loss attributed to the teacher encoders.
|
| 91 |
-
adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
|
| 92 |
-
the model to use. This allows us to overprepare batches and slice from them for the
|
| 93 |
-
data we need for a model run.
|
| 94 |
-
"""
|
| 95 |
-
|
| 96 |
-
self.include_gsp_yield_history = include_gsp_yield_history
|
| 97 |
-
self.include_sun = include_sun
|
| 98 |
-
self.location_id_mapping = location_id_mapping
|
| 99 |
-
self.embedding_dim = embedding_dim
|
| 100 |
-
self.enc_loss_frac = enc_loss_frac
|
| 101 |
-
self.include_sat = False
|
| 102 |
-
self.include_nwp = False
|
| 103 |
-
self.include_pv = False
|
| 104 |
-
self.adapt_batches = adapt_batches
|
| 105 |
-
|
| 106 |
-
self.use_id_embedding = location_id_mapping is not None
|
| 107 |
-
|
| 108 |
-
if self.use_id_embedding:
|
| 109 |
-
num_embeddings = max(location_id_mapping.values()) + 1
|
| 110 |
-
|
| 111 |
-
# This is set but modified later based on the teachers
|
| 112 |
-
self.add_image_embedding_channel = False
|
| 113 |
-
|
| 114 |
-
super().__init__(
|
| 115 |
-
history_minutes=history_minutes,
|
| 116 |
-
forecast_minutes=forecast_minutes,
|
| 117 |
-
optimizer=optimizer,
|
| 118 |
-
output_quantiles=output_quantiles,
|
| 119 |
-
target_key="gsp",
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
# Number of features expected by the output_network
|
| 123 |
-
# Add to this as network pices are constructed
|
| 124 |
-
fusion_input_features = 0
|
| 125 |
-
|
| 126 |
-
self.teacher_models = torch.nn.ModuleDict()
|
| 127 |
-
self.mode_teacher_dict = mode_teacher_dict
|
| 128 |
-
|
| 129 |
-
for mode, path in mode_teacher_dict.items():
|
| 130 |
-
# load teacher model and freeze its weights
|
| 131 |
-
self.teacher_models[mode] = self.get_unimodal_encoder(path, True, val_best=val_best)
|
| 132 |
-
|
| 133 |
-
for param in self.teacher_models[mode].parameters():
|
| 134 |
-
param.requires_grad = False
|
| 135 |
-
|
| 136 |
-
# Recreate model as student
|
| 137 |
-
mode_student_model = self.get_unimodal_encoder(
|
| 138 |
-
path, load_weights=(not cold_start), val_best=val_best
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
if mode == "sat":
|
| 142 |
-
self.include_sat = True
|
| 143 |
-
self.sat_sequence_len = mode_student_model.sat_sequence_len
|
| 144 |
-
self.sat_encoder = mode_student_model.sat_encoder
|
| 145 |
-
|
| 146 |
-
if mode_student_model.add_image_embedding_channel:
|
| 147 |
-
self.sat_embed = mode_student_model.sat_embed
|
| 148 |
-
self.add_image_embedding_channel = True
|
| 149 |
-
|
| 150 |
-
fusion_input_features += self.sat_encoder.out_features
|
| 151 |
-
|
| 152 |
-
elif mode == "site":
|
| 153 |
-
self.include_pv = True
|
| 154 |
-
self.site_encoder = mode_student_model.site_encoder
|
| 155 |
-
fusion_input_features += self.site_encoder.out_features
|
| 156 |
-
|
| 157 |
-
elif mode.startswith("nwp"):
|
| 158 |
-
nwp_source = mode.removeprefix("nwp/")
|
| 159 |
-
|
| 160 |
-
if not self.include_nwp:
|
| 161 |
-
self.include_nwp = True
|
| 162 |
-
self.nwp_encoders_dict = torch.nn.ModuleDict()
|
| 163 |
-
|
| 164 |
-
if mode_student_model.add_image_embedding_channel:
|
| 165 |
-
self.add_image_embedding_channel = True
|
| 166 |
-
self.nwp_embed_dict = torch.nn.ModuleDict()
|
| 167 |
-
|
| 168 |
-
self.nwp_encoders_dict[nwp_source] = mode_student_model.nwp_encoders_dict[
|
| 169 |
-
nwp_source
|
| 170 |
-
]
|
| 171 |
-
|
| 172 |
-
if self.add_image_embedding_channel:
|
| 173 |
-
self.nwp_embed_dict[nwp_source] = mode_student_model.nwp_embed_dict[nwp_source]
|
| 174 |
-
|
| 175 |
-
fusion_input_features += self.nwp_encoders_dict[nwp_source].out_features
|
| 176 |
-
|
| 177 |
-
if self.embedding_dim:
|
| 178 |
-
self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
| 179 |
-
fusion_input_features += embedding_dim
|
| 180 |
-
|
| 181 |
-
if self.include_sun:
|
| 182 |
-
self.sun_fc1 = nn.Linear(
|
| 183 |
-
in_features=2 * (self.forecast_len + self.history_len + 1),
|
| 184 |
-
out_features=16,
|
| 185 |
-
)
|
| 186 |
-
fusion_input_features += 16
|
| 187 |
-
|
| 188 |
-
if include_gsp_yield_history:
|
| 189 |
-
fusion_input_features += self.history_len
|
| 190 |
-
|
| 191 |
-
self.output_network = output_network(
|
| 192 |
-
in_features=fusion_input_features,
|
| 193 |
-
out_features=self.num_output_features,
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
self.save_hyperparameters()
|
| 197 |
-
|
| 198 |
-
def get_unimodal_encoder(self, path, load_weights, val_best):
|
| 199 |
-
"""Load a model to function as a unimodal teacher"""
|
| 200 |
-
|
| 201 |
-
model_config = parse_config(f"{path}/model_config.yaml")
|
| 202 |
-
|
| 203 |
-
# Load the teacher model
|
| 204 |
-
encoder = hydra.utils.instantiate(model_config)
|
| 205 |
-
|
| 206 |
-
if load_weights:
|
| 207 |
-
if val_best:
|
| 208 |
-
# Only one epoch (best) saved per model
|
| 209 |
-
files = glob.glob(f"{path}/epoch*.ckpt")
|
| 210 |
-
assert len(files) == 1
|
| 211 |
-
checkpoint = torch.load(files[0], map_location="cpu")
|
| 212 |
-
else:
|
| 213 |
-
checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu")
|
| 214 |
-
|
| 215 |
-
encoder.load_state_dict(state_dict=checkpoint["state_dict"])
|
| 216 |
-
return encoder
|
| 217 |
-
|
| 218 |
-
def teacher_forward(self, x):
|
| 219 |
-
"""Run the teacher models and return their encodings"""
|
| 220 |
-
|
| 221 |
-
if self.use_id_embedding:
|
| 222 |
-
# eg: x['gsp_id] = [1] with location_id_mapping = {1:0}, would give [0]
|
| 223 |
-
id = torch.tensor(
|
| 224 |
-
[self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
|
| 225 |
-
device=self.device,
|
| 226 |
-
dtype=torch.int64,
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
modes = OrderedDict()
|
| 230 |
-
for mode, teacher_model in self.teacher_models.items():
|
| 231 |
-
# ******************* Satellite imagery *************************
|
| 232 |
-
if mode == "sat":
|
| 233 |
-
# Shape: batch_size, seq_length, channel, height, width
|
| 234 |
-
sat_data = x["satellite_actual"][:, : teacher_model.sat_sequence_len]
|
| 235 |
-
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
|
| 236 |
-
|
| 237 |
-
if self.add_image_embedding_channel:
|
| 238 |
-
sat_data = teacher_model.sat_embed(sat_data, id)
|
| 239 |
-
|
| 240 |
-
modes[mode] = teacher_model.sat_encoder(sat_data)
|
| 241 |
-
|
| 242 |
-
# *********************** NWP Data ************************************
|
| 243 |
-
if mode.startswith("nwp"):
|
| 244 |
-
nwp_source = mode.removeprefix("nwp/")
|
| 245 |
-
|
| 246 |
-
# shape: batch_size, seq_len, n_chans, height, width
|
| 247 |
-
nwp_data = x["nwp"][nwp_source]["nwp"].float()
|
| 248 |
-
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
|
| 249 |
-
nwp_data = torch.clip(nwp_data, min=-50, max=50)
|
| 250 |
-
if teacher_model.add_image_embedding_channel:
|
| 251 |
-
nwp_data = teacher_model.nwp_embed_dict[nwp_source](nwp_data, id)
|
| 252 |
-
|
| 253 |
-
nwp_out = teacher_model.nwp_encoders_dict[nwp_source](nwp_data)
|
| 254 |
-
modes[mode] = nwp_out
|
| 255 |
-
|
| 256 |
-
# *********************** PV Data *************************************
|
| 257 |
-
# Add site-level PV yield
|
| 258 |
-
if mode == "site":
|
| 259 |
-
modes[mode] = teacher_model.site_encoder(x)
|
| 260 |
-
|
| 261 |
-
return modes
|
| 262 |
-
|
| 263 |
-
def forward(self, x, return_modes=False):
|
| 264 |
-
"""Run model forward"""
|
| 265 |
-
|
| 266 |
-
if self.adapt_batches:
|
| 267 |
-
x = self._adapt_batch(x)
|
| 268 |
-
|
| 269 |
-
if self.use_id_embedding:
|
| 270 |
-
# eg: x['gsp_id] = [1] with location_id_mapping = {1:0}, would give [0]
|
| 271 |
-
id = torch.tensor(
|
| 272 |
-
[self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
|
| 273 |
-
device=self.device,
|
| 274 |
-
dtype=torch.int64,
|
| 275 |
-
)
|
| 276 |
-
|
| 277 |
-
modes = OrderedDict()
|
| 278 |
-
# ******************* Satellite imagery *************************
|
| 279 |
-
if self.include_sat:
|
| 280 |
-
# Shape: batch_size, seq_length, channel, height, width
|
| 281 |
-
sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
|
| 282 |
-
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
|
| 283 |
-
|
| 284 |
-
if self.add_image_embedding_channel:
|
| 285 |
-
sat_data = self.sat_embed(sat_data, id)
|
| 286 |
-
modes["sat"] = self.sat_encoder(sat_data)
|
| 287 |
-
|
| 288 |
-
# *********************** NWP Data ************************************
|
| 289 |
-
if self.include_nwp:
|
| 290 |
-
# Loop through potentially many NMPs
|
| 291 |
-
for nwp_source in self.nwp_encoders_dict:
|
| 292 |
-
# shape: batch_size, seq_len, n_chans, height, width
|
| 293 |
-
nwp_data = x["nwp"][nwp_source]["nwp"].float()
|
| 294 |
-
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
|
| 295 |
-
# Some NWP variables can overflow into NaNs when normalised if they have extreme
|
| 296 |
-
# tails
|
| 297 |
-
nwp_data = torch.clip(nwp_data, min=-50, max=50)
|
| 298 |
-
|
| 299 |
-
if self.add_image_embedding_channel:
|
| 300 |
-
nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)
|
| 301 |
-
|
| 302 |
-
nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
|
| 303 |
-
modes[f"nwp/{nwp_source}"] = nwp_out
|
| 304 |
-
|
| 305 |
-
# *********************** PV Data *************************************
|
| 306 |
-
# Add site-level PV yield
|
| 307 |
-
if self.include_pv:
|
| 308 |
-
if self._target_key != "site":
|
| 309 |
-
modes["site"] = self.site_encoder(x)
|
| 310 |
-
else:
|
| 311 |
-
# Target is PV, so only take the history
|
| 312 |
-
pv_history = x["pv"][:, : self.history_len].float()
|
| 313 |
-
modes["site"] = self.site_encoder(pv_history)
|
| 314 |
-
|
| 315 |
-
# *********************** GSP Data ************************************
|
| 316 |
-
# add gsp yield history
|
| 317 |
-
if self.include_gsp_yield_history:
|
| 318 |
-
gsp_history = x["gsp"][:, : self.history_len].float()
|
| 319 |
-
gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
|
| 320 |
-
modes["gsp"] = gsp_history
|
| 321 |
-
|
| 322 |
-
# ********************** Embedding of GSP ID ********************
|
| 323 |
-
if self.use_id_embedding:
|
| 324 |
-
modes["id"] = self.embed(id)
|
| 325 |
-
|
| 326 |
-
if self.include_sun:
|
| 327 |
-
# Use only new direct keys
|
| 328 |
-
sun = torch.cat(
|
| 329 |
-
(
|
| 330 |
-
x["solar_azimuth"],
|
| 331 |
-
x["solar_elevation"],
|
| 332 |
-
),
|
| 333 |
-
dim=1,
|
| 334 |
-
).float()
|
| 335 |
-
sun = self.sun_fc1(sun)
|
| 336 |
-
modes["sun"] = sun
|
| 337 |
-
|
| 338 |
-
out = self.output_network(modes)
|
| 339 |
-
|
| 340 |
-
if self.use_quantile_regression:
|
| 341 |
-
# Shape: batch_size, seq_length * num_quantiles
|
| 342 |
-
out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))
|
| 343 |
-
|
| 344 |
-
if return_modes:
|
| 345 |
-
return out, modes
|
| 346 |
-
else:
|
| 347 |
-
return out
|
| 348 |
-
|
| 349 |
-
def _calculate_teacher_loss(self, modes, teacher_modes):
|
| 350 |
-
enc_losses = {}
|
| 351 |
-
for m, enc in teacher_modes.items():
|
| 352 |
-
enc_losses[f"enc_loss/{m}"] = F.l1_loss(enc, modes[m])
|
| 353 |
-
enc_losses["enc_loss/total"] = sum([v for k, v in enc_losses.items()])
|
| 354 |
-
return enc_losses
|
| 355 |
-
|
| 356 |
-
def training_step(self, batch, batch_idx):
|
| 357 |
-
"""Run training step"""
|
| 358 |
-
y_hat, modes = self.forward(batch, return_modes=True)
|
| 359 |
-
y = batch[self._target_key][:, -self.forecast_len :, 0]
|
| 360 |
-
|
| 361 |
-
losses = self._calculate_common_losses(y, y_hat)
|
| 362 |
-
|
| 363 |
-
teacher_modes = self.teacher_forward(batch)
|
| 364 |
-
teacher_loss = self._calculate_teacher_loss(modes, teacher_modes)
|
| 365 |
-
losses.update(teacher_loss)
|
| 366 |
-
|
| 367 |
-
if self.use_quantile_regression:
|
| 368 |
-
opt_target = losses["quantile_loss"]
|
| 369 |
-
else:
|
| 370 |
-
opt_target = losses["MAE"]
|
| 371 |
-
|
| 372 |
-
t_loss = teacher_loss["enc_loss/total"]
|
| 373 |
-
|
| 374 |
-
# The scales of the two losses
|
| 375 |
-
l_s = opt_target.detach()
|
| 376 |
-
tl_s = max(t_loss.detach(), 1e-9)
|
| 377 |
-
|
| 378 |
-
# opt_target = t_loss/tl_s * l_s * self.enc_loss_frac + opt_target * (1-self.enc_loss_frac)
|
| 379 |
-
losses["opt_loss"] = t_loss / tl_s * l_s * self.enc_loss_frac + opt_target * (
|
| 380 |
-
1 - self.enc_loss_frac
|
| 381 |
-
)
|
| 382 |
-
|
| 383 |
-
losses = {f"{k}/train": v for k, v in losses.items()}
|
| 384 |
-
self._training_accumulate_log(batch, batch_idx, losses, y_hat)
|
| 385 |
-
|
| 386 |
-
return losses["opt_loss/train"]
|
| 387 |
-
|
| 388 |
-
def convert_to_multimodal_model(self, config):
|
| 389 |
-
"""Convert the model into a multimodal model class whilst preserving weights"""
|
| 390 |
-
config = config.copy()
|
| 391 |
-
|
| 392 |
-
if "cold_start" in config:
|
| 393 |
-
del config["cold_start"]
|
| 394 |
-
|
| 395 |
-
config["_target_"] = "pvnet.models.multimodal.multimodal.Model"
|
| 396 |
-
|
| 397 |
-
sources = []
|
| 398 |
-
for mode, path in config["mode_teacher_dict"].items():
|
| 399 |
-
model_config = parse_config(f"{path}/model_config.yaml")
|
| 400 |
-
|
| 401 |
-
if mode.startswith("nwp"):
|
| 402 |
-
nwp_source = mode.removeprefix("nwp/")
|
| 403 |
-
if "nwp_encoders_dict" in config:
|
| 404 |
-
for key in ["nwp_encoders_dict", "nwp_history_minutes", "nwp_forecast_minutes"]:
|
| 405 |
-
config[key][nwp_source] = model_config[key][nwp_source]
|
| 406 |
-
sources.append("nwp")
|
| 407 |
-
else:
|
| 408 |
-
for key in ["nwp_encoders_dict", "nwp_history_minutes", "nwp_forecast_minutes"]:
|
| 409 |
-
config[key] = {nwp_source: model_config[key][nwp_source]}
|
| 410 |
-
config["add_image_embedding_channel"] = model_config["add_image_embedding_channel"]
|
| 411 |
-
|
| 412 |
-
elif mode == "sat":
|
| 413 |
-
for key in [
|
| 414 |
-
"sat_encoder",
|
| 415 |
-
"add_image_embedding_channel",
|
| 416 |
-
"min_sat_delay_minutes",
|
| 417 |
-
"sat_history_minutes",
|
| 418 |
-
]:
|
| 419 |
-
config[key] = model_config[key]
|
| 420 |
-
sources.append("sat")
|
| 421 |
-
|
| 422 |
-
elif mode == "site":
|
| 423 |
-
for key in ["site_encoder", "site_history_minutes"]:
|
| 424 |
-
config[key] = model_config[key]
|
| 425 |
-
sources.append("site")
|
| 426 |
-
|
| 427 |
-
del config["mode_teacher_dict"]
|
| 428 |
-
|
| 429 |
-
# Load the teacher model
|
| 430 |
-
multimodal_model = hydra.utils.instantiate(config)
|
| 431 |
-
|
| 432 |
-
if "sat" in sources:
|
| 433 |
-
multimodal_model.sat_encoder.load_state_dict(self.sat_encoder.state_dict())
|
| 434 |
-
if "nwp" in sources:
|
| 435 |
-
multimodal_model.nwp_encoders_dict.load_state_dict(self.nwp_encoders_dict.state_dict())
|
| 436 |
-
if "site" in sources:
|
| 437 |
-
multimodal_model.site_encoder.load_state_dict(self.site_encoder.state_dict())
|
| 438 |
-
|
| 439 |
-
multimodal_model.output_network.load_state_dict(self.output_network.state_dict())
|
| 440 |
-
|
| 441 |
-
if self.embedding_dim:
|
| 442 |
-
multimodal_model.embed.load_state_dict(self.embed.state_dict())
|
| 443 |
-
|
| 444 |
-
if self.include_sun:
|
| 445 |
-
multimodal_model.sun_fc1.load_state_dict(self.sun_fc1.state_dict())
|
| 446 |
-
|
| 447 |
-
return multimodal_model, config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/models/utils.py
DELETED
|
@@ -1,123 +0,0 @@
|
|
| 1 |
-
"""Utility functions"""
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
logger = logging.getLogger(__name__)
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class PredAccumulator:
|
| 14 |
-
"""A class for accumulating y-predictions using grad accumulation and small batch size.
|
| 15 |
-
|
| 16 |
-
Attributes:
|
| 17 |
-
_y_hats (list[torch.Tensor]): List of prediction tensors
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
def __init__(self):
|
| 21 |
-
"""Prediction accumulator"""
|
| 22 |
-
self._y_hats = []
|
| 23 |
-
|
| 24 |
-
def __bool__(self):
|
| 25 |
-
return len(self._y_hats) > 0
|
| 26 |
-
|
| 27 |
-
def append(self, y_hat: torch.Tensor):
|
| 28 |
-
"""Append a sub-batch of predictions"""
|
| 29 |
-
self._y_hats.append(y_hat)
|
| 30 |
-
|
| 31 |
-
def flush(self) -> torch.Tensor:
|
| 32 |
-
"""Return all appended predictions as single tensor and remove from accumulated store."""
|
| 33 |
-
y_hat = torch.cat(self._y_hats, dim=0)
|
| 34 |
-
self._y_hats = []
|
| 35 |
-
return y_hat
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class DictListAccumulator:
|
| 39 |
-
"""Abstract class for accumulating dictionaries of lists"""
|
| 40 |
-
|
| 41 |
-
@staticmethod
|
| 42 |
-
def _dict_list_append(d1, d2):
|
| 43 |
-
for k, v in d2.items():
|
| 44 |
-
d1[k].append(v)
|
| 45 |
-
|
| 46 |
-
@staticmethod
|
| 47 |
-
def _dict_init_list(d):
|
| 48 |
-
return {k: [v] for k, v in d.items()}
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class MetricAccumulator(DictListAccumulator):
|
| 52 |
-
"""Dictionary of metrics accumulator.
|
| 53 |
-
|
| 54 |
-
A class for accumulating, and finding the mean of logging metrics when using grad
|
| 55 |
-
accumulation and the batch size is small.
|
| 56 |
-
|
| 57 |
-
Attributes:
|
| 58 |
-
_metrics (Dict[str, list[float]]): Dictionary containing lists of metrics.
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
def __init__(self):
|
| 62 |
-
"""Dictionary of metrics accumulator."""
|
| 63 |
-
self._metrics = {}
|
| 64 |
-
|
| 65 |
-
def __bool__(self):
|
| 66 |
-
return self._metrics != {}
|
| 67 |
-
|
| 68 |
-
def append(self, loss_dict: dict[str, float]):
|
| 69 |
-
"""Append lictionary of metrics to self"""
|
| 70 |
-
if not self:
|
| 71 |
-
self._metrics = self._dict_init_list(loss_dict)
|
| 72 |
-
else:
|
| 73 |
-
self._dict_list_append(self._metrics, loss_dict)
|
| 74 |
-
|
| 75 |
-
def flush(self) -> dict[str, float]:
|
| 76 |
-
"""Calculate mean of all accumulated metrics and clear"""
|
| 77 |
-
mean_metrics = {k: np.mean(v) for k, v in self._metrics.items()}
|
| 78 |
-
self._metrics = {}
|
| 79 |
-
return mean_metrics
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
class BatchAccumulator(DictListAccumulator):
|
| 83 |
-
"""A class for accumulating batches when using grad accumulation and the batch size is small.
|
| 84 |
-
|
| 85 |
-
Attributes:
|
| 86 |
-
_batches (Dict[str, list[torch.Tensor]]): Dictionary containing lists of metrics.
|
| 87 |
-
"""
|
| 88 |
-
|
| 89 |
-
def __init__(self, key_to_keep: str = "gsp"):
|
| 90 |
-
"""Batch accumulator"""
|
| 91 |
-
self._batches = {}
|
| 92 |
-
self.key_to_keep = key_to_keep
|
| 93 |
-
|
| 94 |
-
def __bool__(self):
|
| 95 |
-
return self._batches != {}
|
| 96 |
-
|
| 97 |
-
# @staticmethod
|
| 98 |
-
def _filter_batch_dict(self, d):
|
| 99 |
-
keep_keys = [
|
| 100 |
-
self.key_to_keep,
|
| 101 |
-
f"{self.key_to_keep}_id",
|
| 102 |
-
f"{self.key_to_keep}_t0_idx",
|
| 103 |
-
f"{self.key_to_keep}_time_utc",
|
| 104 |
-
]
|
| 105 |
-
return {k: v for k, v in d.items() if k in keep_keys}
|
| 106 |
-
|
| 107 |
-
def append(self, batch: dict[str, list[torch.Tensor]]):
|
| 108 |
-
"""Append batch to self"""
|
| 109 |
-
if not self:
|
| 110 |
-
self._batches = self._dict_init_list(self._filter_batch_dict(batch))
|
| 111 |
-
else:
|
| 112 |
-
self._dict_list_append(self._batches, self._filter_batch_dict(batch))
|
| 113 |
-
|
| 114 |
-
def flush(self) -> dict[str, list[torch.Tensor]]:
|
| 115 |
-
"""Concatenate all accumulated batches, return, and clear self"""
|
| 116 |
-
batch = {}
|
| 117 |
-
for k, v in self._batches.items():
|
| 118 |
-
if k == f"{self.key_to_keep}_t0_idx":
|
| 119 |
-
batch[k] = v[0]
|
| 120 |
-
else:
|
| 121 |
-
batch[k] = torch.cat(v, dim=0)
|
| 122 |
-
self._batches = {}
|
| 123 |
-
return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/optimizers.py
DELETED
|
@@ -1,200 +0,0 @@
|
|
| 1 |
-
"""Optimizer factory-function classes.
|
| 2 |
-
"""
|
| 3 |
-
|
| 4 |
-
from abc import ABC, abstractmethod
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class AbstractOptimizer(ABC):
|
| 10 |
-
"""Abstract class for optimizer
|
| 11 |
-
|
| 12 |
-
Optimizer classes will be used by model like:
|
| 13 |
-
> OptimizerGenerator = AbstractOptimizer()
|
| 14 |
-
> optimizer = OptimizerGenerator(model)
|
| 15 |
-
The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s
|
| 16 |
-
`configure_optimizers()` method.
|
| 17 |
-
See :
|
| 18 |
-
https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers
|
| 19 |
-
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
@abstractmethod
|
| 23 |
-
def __call__(self):
|
| 24 |
-
"""Abstract call"""
|
| 25 |
-
pass
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class Adam(AbstractOptimizer):
|
| 29 |
-
"""Adam optimizer"""
|
| 30 |
-
|
| 31 |
-
def __init__(self, lr=0.0005, **kwargs):
|
| 32 |
-
"""Adam optimizer"""
|
| 33 |
-
self.lr = lr
|
| 34 |
-
self.kwargs = kwargs
|
| 35 |
-
|
| 36 |
-
def __call__(self, model):
|
| 37 |
-
"""Return optimizer"""
|
| 38 |
-
return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class AdamW(AbstractOptimizer):
|
| 42 |
-
"""AdamW optimizer"""
|
| 43 |
-
|
| 44 |
-
def __init__(self, lr=0.0005, **kwargs):
|
| 45 |
-
"""AdamW optimizer"""
|
| 46 |
-
self.lr = lr
|
| 47 |
-
self.kwargs = kwargs
|
| 48 |
-
|
| 49 |
-
def __call__(self, model):
|
| 50 |
-
"""Return optimizer"""
|
| 51 |
-
return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def find_submodule_parameters(model, search_modules):
|
| 55 |
-
"""Finds all parameters within given submodule types
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
model: torch Module to search through
|
| 59 |
-
search_modules: List of submodule types to search for
|
| 60 |
-
"""
|
| 61 |
-
if isinstance(model, search_modules):
|
| 62 |
-
return model.parameters()
|
| 63 |
-
|
| 64 |
-
children = list(model.children())
|
| 65 |
-
if len(children) == 0:
|
| 66 |
-
return []
|
| 67 |
-
else:
|
| 68 |
-
params = []
|
| 69 |
-
for c in children:
|
| 70 |
-
params += find_submodule_parameters(c, search_modules)
|
| 71 |
-
return params
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def find_other_than_submodule_parameters(model, ignore_modules):
|
| 75 |
-
"""Finds all parameters not with given submodule types
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
model: torch Module to search through
|
| 79 |
-
ignore_modules: List of submodule types to ignore
|
| 80 |
-
"""
|
| 81 |
-
if isinstance(model, ignore_modules):
|
| 82 |
-
return []
|
| 83 |
-
|
| 84 |
-
children = list(model.children())
|
| 85 |
-
if len(children) == 0:
|
| 86 |
-
return model.parameters()
|
| 87 |
-
else:
|
| 88 |
-
params = []
|
| 89 |
-
for c in children:
|
| 90 |
-
params += find_other_than_submodule_parameters(c, ignore_modules)
|
| 91 |
-
return params
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
|
| 95 |
-
"""AdamW optimizer and reduce on plateau scheduler"""
|
| 96 |
-
|
| 97 |
-
def __init__(
|
| 98 |
-
self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs
|
| 99 |
-
):
|
| 100 |
-
"""AdamW optimizer and reduce on plateau scheduler"""
|
| 101 |
-
self.lr = lr
|
| 102 |
-
self.weight_decay = weight_decay
|
| 103 |
-
self.patience = patience
|
| 104 |
-
self.factor = factor
|
| 105 |
-
self.threshold = threshold
|
| 106 |
-
self.opt_kwargs = opt_kwargs
|
| 107 |
-
|
| 108 |
-
def __call__(self, model):
|
| 109 |
-
"""Return optimizer"""
|
| 110 |
-
|
| 111 |
-
search_modules = (torch.nn.Embedding,)
|
| 112 |
-
|
| 113 |
-
no_decay = find_submodule_parameters(model, search_modules)
|
| 114 |
-
decay = find_other_than_submodule_parameters(model, search_modules)
|
| 115 |
-
|
| 116 |
-
optim_groups = [
|
| 117 |
-
{"params": decay, "weight_decay": self.weight_decay},
|
| 118 |
-
{"params": no_decay, "weight_decay": 0.0},
|
| 119 |
-
]
|
| 120 |
-
opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
|
| 121 |
-
|
| 122 |
-
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 123 |
-
opt,
|
| 124 |
-
factor=self.factor,
|
| 125 |
-
patience=self.patience,
|
| 126 |
-
threshold=self.threshold,
|
| 127 |
-
)
|
| 128 |
-
sch = {
|
| 129 |
-
"scheduler": sch,
|
| 130 |
-
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
|
| 131 |
-
}
|
| 132 |
-
return [opt], [sch]
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
class AdamWReduceLROnPlateau(AbstractOptimizer):
|
| 136 |
-
"""AdamW optimizer and reduce on plateau scheduler"""
|
| 137 |
-
|
| 138 |
-
def __init__(
|
| 139 |
-
self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, **opt_kwargs
|
| 140 |
-
):
|
| 141 |
-
"""AdamW optimizer and reduce on plateau scheduler"""
|
| 142 |
-
self._lr = lr
|
| 143 |
-
self.patience = patience
|
| 144 |
-
self.factor = factor
|
| 145 |
-
self.threshold = threshold
|
| 146 |
-
self.step_freq = step_freq
|
| 147 |
-
self.opt_kwargs = opt_kwargs
|
| 148 |
-
|
| 149 |
-
def _call_multi(self, model):
|
| 150 |
-
remaining_params = {k: p for k, p in model.named_parameters()}
|
| 151 |
-
|
| 152 |
-
group_args = []
|
| 153 |
-
|
| 154 |
-
for key in self._lr.keys():
|
| 155 |
-
if key == "default":
|
| 156 |
-
continue
|
| 157 |
-
|
| 158 |
-
submodule_params = []
|
| 159 |
-
for param_name in list(remaining_params.keys()):
|
| 160 |
-
if param_name.startswith(key):
|
| 161 |
-
submodule_params += [remaining_params.pop(param_name)]
|
| 162 |
-
|
| 163 |
-
group_args += [{"params": submodule_params, "lr": self._lr[key]}]
|
| 164 |
-
|
| 165 |
-
remaining_params = [p for k, p in remaining_params.items()]
|
| 166 |
-
group_args += [{"params": remaining_params}]
|
| 167 |
-
|
| 168 |
-
opt = torch.optim.AdamW(
|
| 169 |
-
group_args, lr=self._lr["default"] if model.lr is None else model.lr, **self.opt_kwargs
|
| 170 |
-
)
|
| 171 |
-
sch = {
|
| 172 |
-
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 173 |
-
opt,
|
| 174 |
-
factor=self.factor,
|
| 175 |
-
patience=self.patience,
|
| 176 |
-
threshold=self.threshold,
|
| 177 |
-
),
|
| 178 |
-
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
return [opt], [sch]
|
| 182 |
-
|
| 183 |
-
def __call__(self, model):
|
| 184 |
-
"""Return optimizer"""
|
| 185 |
-
if not isinstance(self._lr, float):
|
| 186 |
-
return self._call_multi(model)
|
| 187 |
-
else:
|
| 188 |
-
default_lr = self._lr if model.lr is None else model.lr
|
| 189 |
-
opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
|
| 190 |
-
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 191 |
-
opt,
|
| 192 |
-
factor=self.factor,
|
| 193 |
-
patience=self.patience,
|
| 194 |
-
threshold=self.threshold,
|
| 195 |
-
)
|
| 196 |
-
sch = {
|
| 197 |
-
"scheduler": sch,
|
| 198 |
-
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
|
| 199 |
-
}
|
| 200 |
-
return [opt], [sch]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/training.py
DELETED
|
@@ -1,183 +0,0 @@
|
|
| 1 |
-
"""Training"""
|
| 2 |
-
import os
|
| 3 |
-
import shutil
|
| 4 |
-
from typing import Optional
|
| 5 |
-
|
| 6 |
-
import hydra
|
| 7 |
-
import torch
|
| 8 |
-
from lightning.pytorch import (
|
| 9 |
-
Callback,
|
| 10 |
-
LightningDataModule,
|
| 11 |
-
LightningModule,
|
| 12 |
-
Trainer,
|
| 13 |
-
seed_everything,
|
| 14 |
-
)
|
| 15 |
-
from lightning.pytorch.callbacks import ModelCheckpoint
|
| 16 |
-
from lightning.pytorch.loggers import Logger
|
| 17 |
-
from lightning.pytorch.loggers.wandb import WandbLogger
|
| 18 |
-
from omegaconf import DictConfig, OmegaConf
|
| 19 |
-
|
| 20 |
-
from pvnet import utils
|
| 21 |
-
|
| 22 |
-
log = utils.get_logger(__name__)
|
| 23 |
-
|
| 24 |
-
torch.set_default_dtype(torch.float32)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _callbacks_to_phase(callbacks, phase):
|
| 28 |
-
for c in callbacks:
|
| 29 |
-
if hasattr(c, "switch_phase"):
|
| 30 |
-
c.switch_phase(phase)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def resolve_monitor_loss(output_quantiles):
|
| 34 |
-
"""Return the desired metric to monitor based on whether quantile regression is being used.
|
| 35 |
-
|
| 36 |
-
The adds the option to use something like:
|
| 37 |
-
monitor: "${resolve_monitor_loss:${model.output_quantiles}}"
|
| 38 |
-
|
| 39 |
-
in early stopping and model checkpoint callbacks so the callbacks config does not need to be
|
| 40 |
-
modified depending on whether quantile regression is being used or not.
|
| 41 |
-
"""
|
| 42 |
-
if output_quantiles is None:
|
| 43 |
-
return "MAE/val"
|
| 44 |
-
else:
|
| 45 |
-
return "quantile_loss/val"
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def train(config: DictConfig) -> Optional[float]:
|
| 52 |
-
"""Contains training pipeline.
|
| 53 |
-
|
| 54 |
-
Instantiates all PyTorch Lightning objects from config.
|
| 55 |
-
|
| 56 |
-
Args:
|
| 57 |
-
config (DictConfig): Configuration composed by Hydra.
|
| 58 |
-
|
| 59 |
-
Returns:
|
| 60 |
-
Optional[float]: Metric score for hyperparameter optimization.
|
| 61 |
-
"""
|
| 62 |
-
|
| 63 |
-
# Set seed for random number generators in pytorch, numpy and python.random
|
| 64 |
-
if "seed" in config:
|
| 65 |
-
seed_everything(config.seed, workers=True)
|
| 66 |
-
|
| 67 |
-
# Init lightning datamodule
|
| 68 |
-
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
|
| 69 |
-
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
|
| 70 |
-
|
| 71 |
-
# Init lightning model
|
| 72 |
-
log.info(f"Instantiating model <{config.model._target_}>")
|
| 73 |
-
model: LightningModule = hydra.utils.instantiate(config.model)
|
| 74 |
-
|
| 75 |
-
# Init lightning loggers
|
| 76 |
-
loggers: list[Logger] = []
|
| 77 |
-
if "logger" in config:
|
| 78 |
-
for _, lg_conf in config.logger.items():
|
| 79 |
-
if "_target_" in lg_conf:
|
| 80 |
-
log.info(f"Instantiating logger <{lg_conf._target_}>")
|
| 81 |
-
loggers.append(hydra.utils.instantiate(lg_conf))
|
| 82 |
-
|
| 83 |
-
# Init lightning callbacks
|
| 84 |
-
callbacks: list[Callback] = []
|
| 85 |
-
if "callbacks" in config:
|
| 86 |
-
for _, cb_conf in config.callbacks.items():
|
| 87 |
-
if "_target_" in cb_conf:
|
| 88 |
-
log.info(f"Instantiating callback <{cb_conf._target_}>")
|
| 89 |
-
callbacks.append(hydra.utils.instantiate(cb_conf))
|
| 90 |
-
|
| 91 |
-
# Align the wandb id with the checkpoint path
|
| 92 |
-
# - only works if wandb logger and model checkpoint used
|
| 93 |
-
# - this makes it easy to push the model to huggingface
|
| 94 |
-
use_wandb_logger = False
|
| 95 |
-
for logger in loggers:
|
| 96 |
-
log.info(f"{logger}")
|
| 97 |
-
if isinstance(logger, WandbLogger):
|
| 98 |
-
use_wandb_logger = True
|
| 99 |
-
wandb_logger = logger
|
| 100 |
-
break
|
| 101 |
-
|
| 102 |
-
if use_wandb_logger:
|
| 103 |
-
for callback in callbacks:
|
| 104 |
-
log.info(f"{callback}")
|
| 105 |
-
if isinstance(callback, ModelCheckpoint):
|
| 106 |
-
# Need to call the .experiment property to initialise the logger
|
| 107 |
-
wandb_logger.experiment
|
| 108 |
-
callback.dirpath = "/".join(
|
| 109 |
-
callback.dirpath.split("/")[:-1] + [wandb_logger.version]
|
| 110 |
-
)
|
| 111 |
-
# Also save model config here - this makes for easy model push to huggingface
|
| 112 |
-
os.makedirs(callback.dirpath, exist_ok=True)
|
| 113 |
-
OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml")
|
| 114 |
-
|
| 115 |
-
# Similarly save the data config
|
| 116 |
-
data_config = config.datamodule.configuration
|
| 117 |
-
if data_config is None:
|
| 118 |
-
# Data config can be none if using presaved batches. We go to the presaved
|
| 119 |
-
# batches to get the data config
|
| 120 |
-
data_config = f"{config.datamodule.sample_dir}/data_configuration.yaml"
|
| 121 |
-
|
| 122 |
-
assert os.path.isfile(data_config), f"Data config file not found: {data_config}"
|
| 123 |
-
shutil.copyfile(data_config, f"{callback.dirpath}/data_config.yaml")
|
| 124 |
-
|
| 125 |
-
# upload configuration up to wandb
|
| 126 |
-
OmegaConf.save(config, "./experiment_config.yaml")
|
| 127 |
-
wandb_logger.experiment.save(
|
| 128 |
-
f"{callback.dirpath}/data_config.yaml", callback.dirpath
|
| 129 |
-
)
|
| 130 |
-
wandb_logger.experiment.save("./experiment_config.yaml")
|
| 131 |
-
|
| 132 |
-
break
|
| 133 |
-
|
| 134 |
-
should_pretrain = False
|
| 135 |
-
for c in callbacks:
|
| 136 |
-
should_pretrain |= hasattr(c, "training_phase") and c.training_phase == "pretrain"
|
| 137 |
-
|
| 138 |
-
if should_pretrain:
|
| 139 |
-
_callbacks_to_phase(callbacks, "pretrain")
|
| 140 |
-
|
| 141 |
-
trainer: Trainer = hydra.utils.instantiate(
|
| 142 |
-
config.trainer,
|
| 143 |
-
logger=loggers,
|
| 144 |
-
_convert_="partial",
|
| 145 |
-
callbacks=callbacks,
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
# TODO: remove this option
|
| 149 |
-
if should_pretrain:
|
| 150 |
-
# Pre-train the model
|
| 151 |
-
raise NotImplementedError("Pre-training is not yet supported")
|
| 152 |
-
# The parameter `block_nwp_and_sat` is not available in data-sampler
|
| 153 |
-
# If pretraining is re-supported in the future it is likely any pre-training logic should
|
| 154 |
-
# go here or perhaps in the callbacks
|
| 155 |
-
# datamodule.block_nwp_and_sat = True
|
| 156 |
-
|
| 157 |
-
trainer.fit(model=model, datamodule=datamodule)
|
| 158 |
-
|
| 159 |
-
_callbacks_to_phase(callbacks, "main")
|
| 160 |
-
|
| 161 |
-
trainer.should_stop = False
|
| 162 |
-
|
| 163 |
-
# Train the model completely
|
| 164 |
-
trainer.fit(model=model, datamodule=datamodule)
|
| 165 |
-
|
| 166 |
-
# Make sure everything closed properly
|
| 167 |
-
log.info("Finalizing!")
|
| 168 |
-
utils.finish(
|
| 169 |
-
config=config,
|
| 170 |
-
model=model,
|
| 171 |
-
datamodule=datamodule,
|
| 172 |
-
trainer=trainer,
|
| 173 |
-
callbacks=callbacks,
|
| 174 |
-
loggers=loggers,
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# Print path to best checkpoint
|
| 178 |
-
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
|
| 179 |
-
|
| 180 |
-
# Return metric score for hyperparameter optimization
|
| 181 |
-
optimized_metric = config.get("optimized_metric")
|
| 182 |
-
if optimized_metric:
|
| 183 |
-
return trainer.callback_metrics[optimized_metric]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pvnet/utils.py
DELETED
|
@@ -1,321 +0,0 @@
|
|
| 1 |
-
"""Utils"""
|
| 2 |
-
import logging
|
| 3 |
-
import warnings
|
| 4 |
-
from collections.abc import Sequence
|
| 5 |
-
from typing import Optional
|
| 6 |
-
|
| 7 |
-
import lightning.pytorch as pl
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
import pandas as pd
|
| 10 |
-
import pylab
|
| 11 |
-
import rich.syntax
|
| 12 |
-
import rich.tree
|
| 13 |
-
import xarray as xr
|
| 14 |
-
from lightning.pytorch.loggers import Logger
|
| 15 |
-
from lightning.pytorch.utilities import rank_zero_only
|
| 16 |
-
from ocf_data_sampler.select.location import Location
|
| 17 |
-
from omegaconf import DictConfig, OmegaConf
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
|
| 21 |
-
"""Initializes multi-GPU-friendly python logger."""
|
| 22 |
-
|
| 23 |
-
logger = logging.getLogger(name)
|
| 24 |
-
logger.setLevel(level)
|
| 25 |
-
|
| 26 |
-
# this ensures all logging levels get marked with the rank zero decorator
|
| 27 |
-
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
|
| 28 |
-
for level in (
|
| 29 |
-
"debug",
|
| 30 |
-
"info",
|
| 31 |
-
"warning",
|
| 32 |
-
"error",
|
| 33 |
-
"exception",
|
| 34 |
-
"fatal",
|
| 35 |
-
"critical",
|
| 36 |
-
):
|
| 37 |
-
setattr(logger, level, rank_zero_only(getattr(logger, level)))
|
| 38 |
-
|
| 39 |
-
return logger
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class GSPLocationLookup:
|
| 43 |
-
"""Query object for GSP location from GSP ID"""
|
| 44 |
-
|
| 45 |
-
def __init__(self, x_osgb: xr.DataArray, y_osgb: xr.DataArray):
|
| 46 |
-
"""Query object for GSP location from GSP ID
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
x_osgb: DataArray of the OSGB x-coordinate for any given GSP ID
|
| 50 |
-
y_osgb: DataArray of the OSGB y-coordinate for any given GSP ID
|
| 51 |
-
|
| 52 |
-
"""
|
| 53 |
-
self.x_osgb = x_osgb
|
| 54 |
-
self.y_osgb = y_osgb
|
| 55 |
-
|
| 56 |
-
def __call__(self, gsp_id: int) -> Location:
|
| 57 |
-
"""Returns the locations for the input GSP IDs.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
gsp_id: Integer ID of the GSP
|
| 61 |
-
"""
|
| 62 |
-
return Location(
|
| 63 |
-
x=self.x_osgb.sel(gsp_id=gsp_id).item(),
|
| 64 |
-
y=self.y_osgb.sel(gsp_id=gsp_id).item(),
|
| 65 |
-
id=gsp_id,
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
class SiteLocationLookup:
|
| 70 |
-
"""Query object for site location from site ID"""
|
| 71 |
-
|
| 72 |
-
def __init__(self, long: xr.DataArray, lat: xr.DataArray):
|
| 73 |
-
"""Query object for site location from site ID
|
| 74 |
-
|
| 75 |
-
Args:
|
| 76 |
-
long: DataArray of the longitude coordinates for any given site ID
|
| 77 |
-
lat: DataArray of the latitude coordinates for any given site ID
|
| 78 |
-
|
| 79 |
-
"""
|
| 80 |
-
self.longitude = long
|
| 81 |
-
self.latitude = lat
|
| 82 |
-
|
| 83 |
-
def __call__(self, site_id: int) -> Location:
|
| 84 |
-
"""Returns the locations for the input site IDs.
|
| 85 |
-
|
| 86 |
-
Args:
|
| 87 |
-
site_id: Integer ID of the site
|
| 88 |
-
"""
|
| 89 |
-
return Location(
|
| 90 |
-
coordinate_system="lon_lat",
|
| 91 |
-
x=self.longitude.sel(pv_system_id=site_id).item(),
|
| 92 |
-
y=self.latitude.sel(pv_system_id=site_id).item(),
|
| 93 |
-
id=site_id,
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def extras(config: DictConfig) -> None:
|
| 98 |
-
"""A couple of optional utilities.
|
| 99 |
-
|
| 100 |
-
Controlled by main config file:
|
| 101 |
-
- disabling warnings
|
| 102 |
-
- easier access to debug mode
|
| 103 |
-
- forcing debug friendly configuration
|
| 104 |
-
|
| 105 |
-
Modifies DictConfig in place.
|
| 106 |
-
|
| 107 |
-
Args:
|
| 108 |
-
config (DictConfig): Configuration composed by Hydra.
|
| 109 |
-
"""
|
| 110 |
-
|
| 111 |
-
log = get_logger()
|
| 112 |
-
|
| 113 |
-
# enable adding new keys to config
|
| 114 |
-
OmegaConf.set_struct(config, False)
|
| 115 |
-
|
| 116 |
-
# disable python warnings if <config.ignore_warnings=True>
|
| 117 |
-
if config.get("ignore_warnings"):
|
| 118 |
-
log.info("Disabling python warnings! <config.ignore_warnings=True>")
|
| 119 |
-
warnings.filterwarnings("ignore")
|
| 120 |
-
|
| 121 |
-
# set <config.trainer.fast_dev_run=True> if <config.debug=True>
|
| 122 |
-
if config.get("debug"):
|
| 123 |
-
log.info("Running in debug mode! <config.debug=True>")
|
| 124 |
-
config.trainer.fast_dev_run = True
|
| 125 |
-
|
| 126 |
-
# force debugger friendly configuration if <config.trainer.fast_dev_run=True>
|
| 127 |
-
if config.trainer.get("fast_dev_run"):
|
| 128 |
-
log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
|
| 129 |
-
# Debuggers don't like GPUs or multiprocessing
|
| 130 |
-
if config.trainer.get("gpus"):
|
| 131 |
-
config.trainer.gpus = 0
|
| 132 |
-
if config.datamodule.get("pin_memory"):
|
| 133 |
-
config.datamodule.pin_memory = False
|
| 134 |
-
if config.datamodule.get("num_workers"):
|
| 135 |
-
config.datamodule.num_workers = 0
|
| 136 |
-
|
| 137 |
-
# disable adding new keys to config
|
| 138 |
-
OmegaConf.set_struct(config, True)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
@rank_zero_only
|
| 142 |
-
def print_config(
|
| 143 |
-
config: DictConfig,
|
| 144 |
-
fields: Sequence[str] = (
|
| 145 |
-
"trainer",
|
| 146 |
-
"model",
|
| 147 |
-
"datamodule",
|
| 148 |
-
"callbacks",
|
| 149 |
-
"logger",
|
| 150 |
-
"seed",
|
| 151 |
-
),
|
| 152 |
-
resolve: bool = True,
|
| 153 |
-
) -> None:
|
| 154 |
-
"""Prints content of DictConfig using Rich library and its tree structure.
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
config (DictConfig): Configuration composed by Hydra.
|
| 158 |
-
fields (Sequence[str], optional): Determines which main fields from config will
|
| 159 |
-
be printed and in what order.
|
| 160 |
-
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
|
| 161 |
-
"""
|
| 162 |
-
|
| 163 |
-
style = "dim"
|
| 164 |
-
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
|
| 165 |
-
|
| 166 |
-
for field in fields:
|
| 167 |
-
branch = tree.add(field, style=style, guide_style=style)
|
| 168 |
-
|
| 169 |
-
config_section = config.get(field)
|
| 170 |
-
branch_content = str(config_section)
|
| 171 |
-
if isinstance(config_section, DictConfig):
|
| 172 |
-
branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
|
| 173 |
-
|
| 174 |
-
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
|
| 175 |
-
|
| 176 |
-
rich.print(tree)
|
| 177 |
-
|
| 178 |
-
with open("config_tree.txt", "w") as fp:
|
| 179 |
-
rich.print(tree, file=fp)
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
def empty(*args, **kwargs):
|
| 183 |
-
"""Returns nothing"""
|
| 184 |
-
pass
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
@rank_zero_only
|
| 188 |
-
def log_hyperparameters(
|
| 189 |
-
config: DictConfig,
|
| 190 |
-
model: pl.LightningModule,
|
| 191 |
-
datamodule: pl.LightningDataModule,
|
| 192 |
-
trainer: pl.Trainer,
|
| 193 |
-
callbacks: list[pl.Callback],
|
| 194 |
-
logger: list[Logger],
|
| 195 |
-
) -> None:
|
| 196 |
-
"""This method controls which parameters from Hydra config are saved by Lightning loggers.
|
| 197 |
-
|
| 198 |
-
Additionaly saves:
|
| 199 |
-
- number of trainable model parameters
|
| 200 |
-
"""
|
| 201 |
-
|
| 202 |
-
hparams = {}
|
| 203 |
-
|
| 204 |
-
# choose which parts of hydra config will be saved to loggers
|
| 205 |
-
hparams["trainer"] = config["trainer"]
|
| 206 |
-
hparams["model"] = config["model"]
|
| 207 |
-
hparams["datamodule"] = config["datamodule"]
|
| 208 |
-
if "seed" in config:
|
| 209 |
-
hparams["seed"] = config["seed"]
|
| 210 |
-
if "callbacks" in config:
|
| 211 |
-
hparams["callbacks"] = config["callbacks"]
|
| 212 |
-
|
| 213 |
-
# save number of model parameters
|
| 214 |
-
hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
|
| 215 |
-
hparams["model/params_trainable"] = sum(
|
| 216 |
-
p.numel() for p in model.parameters() if p.requires_grad
|
| 217 |
-
)
|
| 218 |
-
hparams["model/params_not_trainable"] = sum(
|
| 219 |
-
p.numel() for p in model.parameters() if not p.requires_grad
|
| 220 |
-
)
|
| 221 |
-
|
| 222 |
-
# send hparams to all loggers
|
| 223 |
-
trainer.logger.log_hyperparams(hparams)
|
| 224 |
-
|
| 225 |
-
# disable logging any more hyperparameters for all loggers
|
| 226 |
-
# this is just a trick to prevent trainer from logging hparams of model,
|
| 227 |
-
# since we already did that above
|
| 228 |
-
trainer.logger.log_hyperparams = empty
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
def finish(
|
| 232 |
-
config: DictConfig,
|
| 233 |
-
model: pl.LightningModule,
|
| 234 |
-
datamodule: pl.LightningDataModule,
|
| 235 |
-
trainer: pl.Trainer,
|
| 236 |
-
callbacks: list[pl.Callback],
|
| 237 |
-
loggers: list[Logger],
|
| 238 |
-
) -> None:
|
| 239 |
-
"""Makes sure everything closed properly."""
|
| 240 |
-
|
| 241 |
-
# without this sweeps with wandb logger might crash!
|
| 242 |
-
if any([isinstance(logger, pl.loggers.wandb.WandbLogger) for logger in loggers]):
|
| 243 |
-
import wandb
|
| 244 |
-
|
| 245 |
-
wandb.finish()
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
def plot_batch_forecasts(
|
| 249 |
-
batch,
|
| 250 |
-
y_hat,
|
| 251 |
-
batch_idx=None,
|
| 252 |
-
quantiles=None,
|
| 253 |
-
key_to_plot: str = "gsp",
|
| 254 |
-
timesteps_to_plot: Optional[list[int]] = None,
|
| 255 |
-
):
|
| 256 |
-
"""Plot a batch of data and the forecast from that batch"""
|
| 257 |
-
|
| 258 |
-
def _get_numpy(key):
|
| 259 |
-
return batch[key].cpu().numpy().squeeze()
|
| 260 |
-
|
| 261 |
-
y_key = key_to_plot
|
| 262 |
-
y_id_key = f"{key_to_plot}_id"
|
| 263 |
-
time_utc_key = f"{key_to_plot}_time_utc"
|
| 264 |
-
y = batch[y_key].cpu().numpy() # Select the one it is trained on
|
| 265 |
-
y_hat = y_hat.cpu().numpy()
|
| 266 |
-
# Select between the timesteps in timesteps to plot
|
| 267 |
-
plotting_name = key_to_plot.upper()
|
| 268 |
-
|
| 269 |
-
gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
|
| 270 |
-
|
| 271 |
-
times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[ns]")
|
| 272 |
-
times_utc = [pd.to_datetime(t) for t in times_utc]
|
| 273 |
-
if timesteps_to_plot is not None:
|
| 274 |
-
y = y[:, timesteps_to_plot[0] : timesteps_to_plot[1]]
|
| 275 |
-
y_hat = y_hat[:, timesteps_to_plot[0] : timesteps_to_plot[1]]
|
| 276 |
-
times_utc = [t[timesteps_to_plot[0] : timesteps_to_plot[1]] for t in times_utc]
|
| 277 |
-
|
| 278 |
-
batch_size = y.shape[0]
|
| 279 |
-
|
| 280 |
-
fig, axes = plt.subplots(4, 4, figsize=(16, 16))
|
| 281 |
-
|
| 282 |
-
for i, ax in enumerate(axes.ravel()):
|
| 283 |
-
if i >= batch_size:
|
| 284 |
-
ax.axis("off")
|
| 285 |
-
continue
|
| 286 |
-
ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$")
|
| 287 |
-
|
| 288 |
-
if quantiles is None:
|
| 289 |
-
ax.plot(
|
| 290 |
-
times_utc[i][-len(y_hat[i]) :], y_hat[i], marker=".", color="r", label=r"$\hat{y}$"
|
| 291 |
-
)
|
| 292 |
-
else:
|
| 293 |
-
cm = pylab.get_cmap("twilight")
|
| 294 |
-
for nq, q in enumerate(quantiles):
|
| 295 |
-
ax.plot(
|
| 296 |
-
times_utc[i][-len(y_hat[i]) :],
|
| 297 |
-
y_hat[i, :, nq],
|
| 298 |
-
color=cm(q),
|
| 299 |
-
label=r"$\hat{y}$" + f"({q})",
|
| 300 |
-
alpha=0.7,
|
| 301 |
-
)
|
| 302 |
-
|
| 303 |
-
ax.set_title(f"ID: {gsp_ids[i]} | {times_utc[i][0].date()}", fontsize="small")
|
| 304 |
-
|
| 305 |
-
xticks = [t for t in times_utc[i] if t.minute == 0][::2]
|
| 306 |
-
ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90)
|
| 307 |
-
ax.grid()
|
| 308 |
-
|
| 309 |
-
axes[0, 0].legend(loc="best")
|
| 310 |
-
|
| 311 |
-
for ax in axes[-1, :]:
|
| 312 |
-
ax.set_xlabel("Time (hour of day)")
|
| 313 |
-
|
| 314 |
-
if batch_idx is not None:
|
| 315 |
-
title = f"Normed {plotting_name} output : batch_idx={batch_idx}"
|
| 316 |
-
else:
|
| 317 |
-
title = f"Normed {plotting_name} output"
|
| 318 |
-
plt.suptitle(title)
|
| 319 |
-
plt.tight_layout()
|
| 320 |
-
|
| 321 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|