peterdudfield committed on
Commit
7bffb2f
·
1 Parent(s): b74423e

Delete pvnet

Browse files
Files changed (37) hide show
  1. pvnet/__init__.py +0 -2
  2. pvnet/callbacks.py +0 -129
  3. pvnet/data/__init__.py +0 -3
  4. pvnet/data/base_datamodule.py +0 -118
  5. pvnet/data/site_datamodule.py +0 -53
  6. pvnet/data/uk_regional_datamodule.py +0 -54
  7. pvnet/load_model.py +0 -71
  8. pvnet/models/__init__.py +0 -1
  9. pvnet/models/base_model.py +0 -973
  10. pvnet/models/baseline/__init__.py +0 -1
  11. pvnet/models/baseline/last_value.py +0 -42
  12. pvnet/models/baseline/readme.md +0 -5
  13. pvnet/models/baseline/single_value.py +0 -36
  14. pvnet/models/ensemble.py +0 -74
  15. pvnet/models/model_cards/pv_india_model_card_template.md +0 -56
  16. pvnet/models/model_cards/pv_uk_regional_model_card_template.md +0 -59
  17. pvnet/models/model_cards/wind_india_model_card_template.md +0 -56
  18. pvnet/models/multimodal/__init__.py +0 -1
  19. pvnet/models/multimodal/basic_blocks.py +0 -104
  20. pvnet/models/multimodal/encoders/__init__.py +0 -1
  21. pvnet/models/multimodal/encoders/basic_blocks.py +0 -217
  22. pvnet/models/multimodal/encoders/encoders2d.py +0 -413
  23. pvnet/models/multimodal/encoders/encoders3d.py +0 -402
  24. pvnet/models/multimodal/encoders/encodersRNN.py +0 -141
  25. pvnet/models/multimodal/linear_networks/__init__.py +0 -1
  26. pvnet/models/multimodal/linear_networks/basic_blocks.py +0 -121
  27. pvnet/models/multimodal/linear_networks/networks.py +0 -332
  28. pvnet/models/multimodal/multimodal.py +0 -417
  29. pvnet/models/multimodal/readme.md +0 -11
  30. pvnet/models/multimodal/site_encoders/__init__.py +0 -1
  31. pvnet/models/multimodal/site_encoders/basic_blocks.py +0 -35
  32. pvnet/models/multimodal/site_encoders/encoders.py +0 -284
  33. pvnet/models/multimodal/unimodal_teacher.py +0 -447
  34. pvnet/models/utils.py +0 -123
  35. pvnet/optimizers.py +0 -200
  36. pvnet/training.py +0 -183
  37. pvnet/utils.py +0 -321
pvnet/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- """PVNet"""
2
- __version__ = "4.1.18"
 
 
 
pvnet/callbacks.py DELETED
@@ -1,129 +0,0 @@
1
- """Custom callbacks
2
- """
3
- from lightning.pytorch import Trainer
4
- from lightning.pytorch.callbacks import BaseFinetuning, EarlyStopping, LearningRateFinder
5
- from lightning.pytorch.trainer.states import TrainerFn
6
-
7
-
8
class PhaseEarlyStopping(EarlyStopping):
    """Monitor a validation metric and stop training when it stops improving.

    Only functions in a specific phase of training. Subclasses set `training_phase` to the name
    of the phase in which the early-stopping check should run.
    """

    # Name of the training phase in which this callback runs its check
    training_phase = None

    # Class-level default so `_should_skip_check` does not raise an AttributeError when it is
    # called before `switch_phase` has been called for the first time
    active = True

    def switch_phase(self, phase: str):
        """Activate the callback if `phase` matches `training_phase`, else deactivate it"""
        if phase == self.training_phase:
            self.activate()
        else:
            self.deactivate()

    def deactivate(self):
        """Deactivate callback"""
        self.active = False

    def activate(self):
        """Activate callback"""
        self.active = True

    def _should_skip_check(self, trainer: Trainer) -> bool:
        # Skip when not fitting, when sanity checking, or when this callback is inactive
        return (
            (trainer.state.fn != TrainerFn.FITTING) or (trainer.sanity_checking) or not self.active
        )


class PretrainEarlyStopping(PhaseEarlyStopping):
    """Monitor a validation metric and stop training when it stops improving.

    Only functions in the 'pretrain' phase of training.
    """

    training_phase = "pretrain"


class MainEarlyStopping(PhaseEarlyStopping):
    """Monitor a validation metric and stop training when it stops improving.

    Only functions in the 'main' phase of training.
    """

    training_phase = "main"
53
-
54
-
55
class PretrainFreeze(BaseFinetuning):
    """Freeze the satellite and NWP encoders during pretraining"""

    training_phase = "pretrain"

    def __init__(self):
        """Freeze the satellite and NWP encoders during pretraining"""
        super().__init__()
        # Initialise the flag so `finetune_function` does not raise an AttributeError when it
        # is called before the first `switch_phase`
        self.active = True

    @staticmethod
    def _encoder_modules(pl_module) -> list:
        """Collect the satellite and NWP encoders which `pl_module` is flagged as including"""
        modules = []
        if pl_module.include_sat:
            modules += [pl_module.sat_encoder]
        if pl_module.include_nwp:
            modules += [pl_module.nwp_encoder]
        return modules

    def freeze_before_training(self, pl_module):
        """Freeze satellite and NWP encoders before training start"""
        self.freeze(self._encoder_modules(pl_module))

    def finetune_function(self, pl_module, current_epoch, optimizer):
        """Unfreeze satellite and NWP encoders once the callback has been deactivated"""
        if not self.active:
            self.unfreeze_and_add_param_group(
                modules=self._encoder_modules(pl_module),
                optimizer=optimizer,
                train_bn=True,
            )

    def switch_phase(self, phase: str):
        """Activate the callback if `phase` matches `training_phase`, else deactivate it"""
        if phase == self.training_phase:
            self.activate()
        else:
            self.deactivate()

    def deactivate(self):
        """Deactivate callback"""
        self.active = False

    def activate(self):
        """Activate callback"""
        self.active = True
102
-
103
-
104
class PhasedLearningRateFinder(LearningRateFinder):
    """Learning rate finder which runs at the first epoch after each phase switch"""

    # The finder runs on the next epoch start whenever this flag is set
    active = True

    def on_fit_start(self, *args, **kwargs):
        """Intentionally a no-op; the search is run from `on_train_epoch_start` instead"""
        return

    def on_train_epoch_start(self, trainer, pl_module):
        """Run the learning rate search if active, then deactivate"""
        if not self.active:
            return
        self.lr_find(trainer, pl_module)
        self.deactivate()

    def switch_phase(self, phase: str):
        """Re-activate the finder on any phase switch; `phase` itself is not inspected"""
        self.activate()

    def deactivate(self):
        """Clear the active flag"""
        self.active = False

    def activate(self):
        """Set the active flag"""
        self.active = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/data/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- """Data parts"""
2
- from .site_datamodule import SiteDataModule
3
- from .uk_regional_datamodule import DataModule
 
 
 
 
pvnet/data/base_datamodule.py DELETED
@@ -1,118 +0,0 @@
1
- """ Data module for pytorch lightning """
2
-
3
- from glob import glob
4
-
5
- from lightning.pytorch import LightningDataModule
6
- from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
7
- from ocf_data_sampler.torch_datasets.sample.base import (
8
- NumpyBatch,
9
- SampleBase,
10
- TensorBatch,
11
- batch_to_tensor,
12
- )
13
- from torch.utils.data import DataLoader, Dataset
14
-
15
-
16
def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
    """Stack a list of numpy samples into a single batch and convert it to tensors"""
    stacked_batch = stack_np_samples_into_batch(samples)
    return batch_to_tensor(stacked_batch)
19
-
20
-
21
class PremadeSamplesDataset(Dataset):
    """Dataset to load samples from

    Args:
        sample_dir: Path to the directory of pre-saved samples.
        sample_class: sample class type to use for save/load/to_numpy
    """

    def __init__(self, sample_dir: str, sample_class: SampleBase):
        """Initialise PremadeSamplesDataset"""
        # `glob` returns paths in arbitrary order, so sort them to make the mapping from
        # index to sample reproducible between runs and machines
        self.sample_paths = sorted(glob(f"{sample_dir}/*"))
        self.sample_class = sample_class

    def __len__(self):
        """Return the number of sample files found in the sample directory"""
        return len(self.sample_paths)

    def __getitem__(self, idx):
        """Load the sample at position `idx` and return the result of its `to_numpy()`"""
        sample = self.sample_class.load(self.sample_paths[idx])
        return sample.to_numpy()
40
-
41
-
42
class BaseDataModule(LightningDataModule):
    """Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler."""

    def __init__(
        self,
        configuration: str | None = None,
        sample_dir: str | None = None,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        train_period: list[str | None] | None = None,
        val_period: list[str | None] | None = None,
    ):
        """Base Datamodule for training pvnet architecture.

        Can also be used with pre-made batches if `sample_dir` is set.

        Args:
            configuration: Path to ocf-data-sampler configuration file.
            sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
                `configuration` or '[train/val]_period'.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of data will be prefetched at the end of each worker process.
            train_period: Date range filter for train dataloader. Defaults to `[None, None]`.
            val_period: Date range filter for val dataloader. Defaults to `[None, None]`.

        Raises:
            ValueError: If both or neither of `sample_dir` and `configuration` are set, or if
                a period is set together with `sample_dir`.
        """
        super().__init__()

        # Copy into fresh lists: this avoids sharing a mutable default between instances and
        # lets any equivalent sequence (e.g. a tuple) compare equal to `[None, None]` below
        train_period = [None, None] if train_period is None else list(train_period)
        val_period = [None, None] if val_period is None else list(val_period)

        if (sample_dir is None) == (configuration is None):
            raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")

        if sample_dir is not None:
            if any(period != [None, None] for period in (train_period, val_period)):
                raise ValueError("Cannot set `(train/val)_period` with presaved samples")

        self.configuration = configuration
        self.sample_dir = sample_dir
        self.train_period = train_period
        self.val_period = val_period

        # Keyword arguments shared by the train and val dataloaders
        self._common_dataloader_kwargs = dict(
            batch_size=batch_size,
            sampler=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
        )

    def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
        """Build a dataset for the given period. Must be implemented by subclasses."""
        raise NotImplementedError

    def _get_premade_samples_dataset(self, subdir) -> Dataset:
        """Build a dataset from pre-saved samples in `subdir`. Must be implemented by subclasses."""
        raise NotImplementedError

    def train_dataloader(self) -> DataLoader:
        """Construct train dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("train")
        else:
            dataset = self._get_streamed_samples_dataset(*self.train_period)
        return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        """Construct val dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("val")
        else:
            dataset = self._get_streamed_samples_dataset(*self.val_period)
        return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/data/site_datamodule.py DELETED
@@ -1,53 +0,0 @@
1
- """ Data module for pytorch lightning """
2
-
3
- from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
4
- from ocf_data_sampler.torch_datasets.sample.site import SiteSample
5
- from torch.utils.data import Dataset
6
-
7
- from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
8
-
9
-
10
class SiteDataModule(BaseDataModule):
    """Site-level datamodule built on the `ocf-data-sampler` site dataset and sample classes."""

    def __init__(
        self,
        configuration: str | None = None,
        sample_dir: str | None = None,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        train_period: list[str | None] = [None, None],
        val_period: list[str | None] = [None, None],
    ):
        """Set up the site datamodule; all arguments are forwarded to `BaseDataModule`.

        Pre-made samples are used instead of streamed samples if `sample_dir` is set.

        Args:
            configuration: Path to configuration file.
            sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
                `configuration` or '[train/val]_period'.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of data will be prefetched at the end of each worker process.
            train_period: Date range filter for train dataloader.
            val_period: Date range filter for val dataloader.

        """
        base_kwargs = dict(
            configuration=configuration,
            sample_dir=sample_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
            train_period=train_period,
            val_period=val_period,
        )
        super().__init__(**base_kwargs)

    def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
        # Samples are produced by `SitesDataset` from the configuration for the given period
        return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)

    def _get_premade_samples_dataset(self, subdir) -> Dataset:
        # Pre-saved samples are read from the `subdir` folder inside `sample_dir`
        return PremadeSamplesDataset(f"{self.sample_dir}/{subdir}", SiteSample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/data/uk_regional_datamodule.py DELETED
@@ -1,54 +0,0 @@
1
- """ Data module for pytorch lightning """
2
-
3
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
4
- from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample
5
- from torch.utils.data import Dataset
6
-
7
- from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
8
-
9
-
10
class DataModule(BaseDataModule):
    """UK regional datamodule built on the `ocf-data-sampler` dataset and sample classes."""

    def __init__(
        self,
        configuration: str | None = None,
        sample_dir: str | None = None,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        train_period: list[str | None] = [None, None],
        val_period: list[str | None] = [None, None],
    ):
        """Set up the UK regional datamodule; all arguments are forwarded to `BaseDataModule`.

        Pre-made samples are used instead of streamed samples if `sample_dir` is set.

        Args:
            configuration: Path to configuration file.
            sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
                `configuration` or '[train/val]_period'.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of data will be prefetched at the end of each worker process.
            train_period: Date range filter for train dataloader.
            val_period: Date range filter for val dataloader.

        """
        base_kwargs = dict(
            configuration=configuration,
            sample_dir=sample_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
            train_period=train_period,
            val_period=val_period,
        )
        super().__init__(**base_kwargs)

    def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
        # Samples are produced by `PVNetUKRegionalDataset` from the configuration for the period
        return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)

    def _get_premade_samples_dataset(self, subdir) -> Dataset:
        # Pre-saved samples are read from the `subdir` folder inside `sample_dir`
        # Returns a dict of np arrays
        return PremadeSamplesDataset(f"{self.sample_dir}/{subdir}", UKRegionalSample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/load_model.py DELETED
@@ -1,71 +0,0 @@
1
- """ Load a model from its checkpoint directory """
2
- import glob
3
- import os
4
-
5
- import hydra
6
- import torch
7
- from pyaml_env import parse_config
8
-
9
- from pvnet.models.ensemble import Ensemble
10
- from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel
11
-
12
-
13
def get_model_from_checkpoints(
    checkpoint_dir_paths: list[str],
    val_best: bool = True,
):
    """Load a model from its checkpoint directory

    Args:
        checkpoint_dir_paths: Checkpoint directories, one per model. Each must contain a
            `model_config.yaml`. If more than one is given the models are wrapped in an
            `Ensemble`.
        val_best: If True, load the single `epoch*.ckpt` file found in each directory.
            Otherwise load `last.ckpt`.

    Returns:
        Tuple of `(model, model_config, data_config)`. `data_config` is the path to the
        `data_config.yaml` of the first checkpoint directory, or None if that file does not
        exist.

    Raises:
        ValueError: If `checkpoint_dir_paths` is empty, or if `val_best` is True and a
            directory does not contain exactly one `epoch*.ckpt` file.
    """
    if not checkpoint_dir_paths:
        # Fail with a clear message instead of an IndexError after the loop
        raise ValueError("`checkpoint_dir_paths` must contain at least one checkpoint directory.")

    model_configs = []
    models = []
    data_configs = []

    for path in checkpoint_dir_paths:
        # Load the model
        model_config = parse_config(f"{path}/model_config.yaml")

        model = hydra.utils.instantiate(model_config)

        if val_best:
            # Only one epoch (best) saved per model
            files = glob.glob(f"{path}/epoch*.ckpt")
            if len(files) != 1:
                raise ValueError(
                    f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one."
                )
            checkpoint_path = files[0]
        else:
            checkpoint_path = f"{path}/last.ckpt"

        # TODO: Loading with weights_only=False is not recommended
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        model.load_state_dict(state_dict=checkpoint["state_dict"])

        if isinstance(model, UMTModel):
            model, model_config = model.convert_to_multimodal_model(model_config)

        # Check for data config
        data_config = f"{path}/data_config.yaml"
        data_configs.append(data_config if os.path.isfile(data_config) else None)

        model_configs.append(model_config)
        models.append(model)

    if len(checkpoint_dir_paths) > 1:
        model_config = {
            "_target_": "pvnet.models.ensemble.Ensemble",
            "model_list": model_configs,
        }
        model = Ensemble(model_list=models)
    else:
        model_config = model_configs[0]
        model = models[0]

    # The data config of the first checkpoint directory is used in both cases
    return model, model_config, data_configs[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Models for PVNet"""
 
 
pvnet/models/base_model.py DELETED
@@ -1,973 +0,0 @@
1
- """Base model for all PVNet submodels"""
2
- import copy
3
- import logging
4
- import os
5
- import tempfile
6
- import time
7
- from pathlib import Path
8
- from typing import Dict, Optional, Union
9
-
10
- import hydra
11
- import lightning.pytorch as pl
12
- import matplotlib.pyplot as plt
13
- import pandas as pd
14
- import pkg_resources
15
- import torch
16
- import torch.nn.functional as F
17
- import wandb
18
- import yaml
19
- from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
20
- from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
21
- from huggingface_hub.file_download import hf_hub_download
22
- from huggingface_hub.hf_api import HfApi
23
- from ocf_data_sampler.torch_datasets.sample.base import copy_batch_to_device
24
- from torchvision.transforms.functional import center_crop
25
-
26
- from pvnet.models.utils import (
27
- BatchAccumulator,
28
- MetricAccumulator,
29
- PredAccumulator,
30
- )
31
- from pvnet.optimizers import AbstractOptimizer
32
- from pvnet.utils import plot_batch_forecasts
33
-
34
- DATA_CONFIG_NAME = "data_config.yaml"
35
- MODEL_CONFIG_NAME = "model_config.yaml"
36
-
37
-
38
- logger = logging.getLogger(__name__)
39
-
40
- activities = [torch.profiler.ProfilerActivity.CPU]
41
- if torch.cuda.is_available():
42
- activities.append(torch.profiler.ProfilerActivity.CUDA)
43
-
44
-
45
def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):
    """Resave the data config and replace the filepaths with a placeholder.

    Args:
        input_path: Path to input configuration file
        output_path: Location to save the output configuration file
        placeholder: String placeholder for data sources
    """
    with open(input_path) as cfg:
        # NOTE(review): FullLoader is kept as in the original; consider `yaml.safe_load` if the
        # config could ever come from an untrusted source
        config = yaml.load(cfg, Loader=yaml.FullLoader)

    config["general"]["description"] = "Config for training the saved PVNet model"
    config["general"]["name"] = "PVNet current"

    input_data = config["input_data"]

    for source in ["gsp", "satellite", "hrvsatellite"]:
        if source in input_data:
            # If not empty - i.e. if used
            if input_data[source]["zarr_path"] != "":
                input_data[source]["zarr_path"] = f"{placeholder}.zarr"

    if "nwp" in input_data:
        for nwp_source in input_data["nwp"]:
            if input_data["nwp"][nwp_source]["zarr_path"] != "":
                input_data["nwp"][nwp_source]["zarr_path"] = f"{placeholder}.zarr"

    if "pv" in input_data:
        for d in input_data["pv"]["pv_files_groups"]:
            d["pv_filename"] = f"{placeholder}.netcdf"
            d["pv_metadata_filename"] = f"{placeholder}.csv"

    if "sensor" in input_data:
        # Index the sensor entry explicitly rather than via a loop variable left over from
        # the loops above, which pointed at a different source
        # If not empty - i.e. if used
        if input_data["sensor"]["sensor_filename"] != "":
            input_data["sensor"]["sensor_filename"] = f"{placeholder}.nc"

    with open(output_path, "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)
82
-
83
-
84
def minimize_data_config(input_path, output_path, model):
    """Remove unused inputs from the data config and match the rest to the model

    Args:
        input_path: Path to input configuration file
        output_path: Location to save the output configuration file
        model: The PVNet model object
    """
    with open(input_path) as cfg:
        config = yaml.load(cfg, Loader=yaml.FullLoader)

    input_data = config["input_data"]

    if "nwp" in input_data:
        if model.include_nwp:
            # Iterate over a copy of the items since unused sources are deleted in the loop
            for nwp_source, nwp_config in list(input_data["nwp"].items()):
                if nwp_source not in model.nwp_encoders_dict:
                    # If not used, delete this source from the config
                    del input_data["nwp"][nwp_source]
                    continue

                encoder = model.nwp_encoders_dict[nwp_source]

                # Replace the image size
                nwp_config["image_size_pixels_height"] = encoder.image_size_pixels
                nwp_config["image_size_pixels_width"] = encoder.image_size_pixels

                # Replace interval_end_minutes so the window spans the encoder's sequence length
                n_steps_after_start = encoder.sequence_length - 1
                nwp_config["interval_end_minutes"] = (
                    nwp_config["interval_start_minutes"]
                    + n_steps_after_start * nwp_config["time_resolution_minutes"]
                )
        else:
            del input_data["nwp"]

    if "satellite" in input_data:
        if model.include_sat:
            sat_config = input_data["satellite"]
            sat_encoder = model.sat_encoder

            # Replace the image size
            sat_config["image_size_pixels_height"] = sat_encoder.image_size_pixels
            sat_config["image_size_pixels_width"] = sat_encoder.image_size_pixels

            # Replace interval_end_minutes so the window spans the encoder's sequence length
            n_steps_after_start = sat_encoder.sequence_length - 1
            sat_config["interval_end_minutes"] = (
                sat_config["interval_start_minutes"]
                + n_steps_after_start * sat_config["time_resolution_minutes"]
            )
        else:
            del input_data["satellite"]

    if "pv" in input_data and not model.include_pv:
        del input_data["pv"]

    # The gsp and solar position windows both end at the model's forecast horizon
    if "gsp" in input_data:
        input_data["gsp"]["interval_end_minutes"] = model.forecast_minutes

    if "solar_position" in input_data:
        input_data["solar_position"]["interval_end_minutes"] = model.forecast_minutes

    with open(output_path, "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)
152
-
153
-
154
def download_hf_hub_with_retries(
    repo_id,
    filename,
    revision,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    token,
    local_files_only,
    max_retries=5,
    wait_time=10,
):
    """
    Tries to download a file from HuggingFace up to max_retries times.

    Args:
        repo_id (str): HuggingFace repo ID
        filename (str): Name of the file to download
        revision (str): Specific model revision
        cache_dir (str): Cache directory
        force_download (bool): Whether to force a new download
        proxies (dict): Proxy settings
        resume_download (bool): Resume interrupted downloads
        token (str): HuggingFace auth token
        local_files_only (bool): Use local files only
        max_retries (int): Maximum number of attempts. Must be at least 1
        wait_time (int): Wait time (in seconds) before retrying

    Returns:
        str: The local file path of the downloaded file

    Raises:
        ValueError: If `max_retries` is less than 1.
        Exception: If every attempt fails. The error from the last attempt is chained as the
            cause.
    """
    if max_retries < 1:
        # Without this guard the loop below would not run and None would be returned silently
        raise ValueError(f"`max_retries` must be at least 1, got {max_retries}.")

    for attempt in range(1, max_retries + 1):
        try:
            return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        except Exception as e:
            # Broad catch is deliberate: any failure triggers a retry until attempts run out
            if attempt == max_retries:
                raise Exception(
                    f"Failed to download {filename} from {repo_id} after {max_retries} attempts."
                ) from e
            logging.warning(
                (
                    f"Attempt {attempt}/{max_retries} failed to download {filename} "
                    f"from {repo_id}. Retrying in {wait_time} seconds..."
                )
            )
            time.sleep(wait_time)
211
-
212
-
213
class PVNetModelHubMixin(PyTorchModelHubMixin):
    """
    Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
    """

    @classmethod
    def from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: Optional[bool] = None,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        Args:
            model_id: Either a local directory containing the weights and model config files,
                or the ID of a repo on the HuggingFace Hub.
            revision: Revision of the Hub repo to download from. Not used for a local directory.
            cache_dir: Passed to the Hub download.
            force_download: Passed to the Hub download.
            proxies: Passed to the Hub download.
            resume_download: Passed to the Hub download.
            local_files_only: Passed to the Hub download.
            token: Passed to the Hub download.
            map_location: Device onto which the saved weights are loaded.
            strict: Passed to `load_state_dict`.

        Returns:
            The model instantiated from the model config, with the weights loaded and set to
            eval mode.
        """

        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
            config_file = os.path.join(model_id, MODEL_CONFIG_NAME)
        else:
            # Both files are fetched with identical download settings
            download_kwargs = dict(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
                max_retries=5,
                wait_time=10,
            )

            # load model file
            model_file = download_hf_hub_with_retries(
                filename=PYTORCH_WEIGHTS_NAME, **download_kwargs
            )

            # load config file
            config_file = download_hf_hub_with_retries(
                filename=MODEL_CONFIG_NAME, **download_kwargs
            )

        with open(config_file, "r") as f:
            config = yaml.safe_load(f)

        model = hydra.utils.instantiate(config)

        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore

        return model

    @classmethod
    def get_data_config(
        cls,
        model_id: str,
        revision: str,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
    ):
        """Load data config file.

        Returns:
            Path to the data config file, either inside the local directory `model_id` or as
            downloaded from the Hub repo `model_id`.
        """
        if os.path.isdir(model_id):
            print("Loading data config from local directory")
            data_config_file = os.path.join(model_id, DATA_CONFIG_NAME)
        else:
            data_config_file = download_hf_hub_with_retries(
                repo_id=model_id,
                filename=DATA_CONFIG_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
                max_retries=5,
                wait_time=10,
            )

        return data_config_file

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # If the model is wrapped and exposes `.module`, save the wrapped module's weights
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        config: dict,
        data_config: Optional[Union[str, Path]],
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        wandb_repo: Optional[str] = None,
        wandb_ids: Optional[Union[list[str], str]] = None,
        card_template_path: Optional[Path] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict`):
                Model configuration specified as a key/value dictionary.
            data_config (`str` or `Path`):
                The path to the data config.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Required if `push_to_hub=True`. Also used to
                select the model card template.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the HuggingFace Hub after saving it.
            wandb_repo: Identifier of the repo on wandb.
            wandb_ids: Identifier(s) of the model on wandb.
            card_template_path: Path to the HuggingFace model card template. Defaults to card in
                PVNet library if set to None.
            kwargs:
                Not used by this method.

        Raises:
            ValueError: If `push_to_hub` is True but `repo_id` is not set.
        """

        # Fail before anything is written rather than part-way through at upload time
        if push_to_hub and repo_id is None:
            raise ValueError("`repo_id` must be set when `push_to_hub=True`.")

        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory)

        # saving model and data config
        if isinstance(config, dict):
            with open(save_directory / MODEL_CONFIG_NAME, "w") as f:
                yaml.dump(config, f, sort_keys=False, default_flow_style=False)

        # Save cleaned configuration file
        if data_config is not None:
            new_data_config_path = save_directory / DATA_CONFIG_NAME

            # Replace the input filenames with place holders
            make_clean_data_config(data_config, new_data_config_path)

            # Tailor the data config to the model being saved
            minimize_data_config(new_data_config_path, new_data_config_path, self)

        card = self.create_hugging_face_model_card(
            repo_id, wandb_repo, wandb_ids, card_template_path
        )

        (save_directory / "README.md").write_text(str(card))

        if push_to_hub:
            api = HfApi()

            api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=save_directory,
            )

            # Print the most recent commit hash
            c = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0]

            message = (
                f"The latest commit is now: \n"
                f" date: {c.created_at} \n"
                f" commit hash: {c.commit_id}\n"
                f" by: {c.authors}\n"
                f" title: {c.title}\n"
            )

            print(message)

        return None

    @staticmethod
    def create_hugging_face_model_card(
        repo_id: Optional[str] = None,
        wandb_repo: Optional[str] = None,
        wandb_ids: Optional[Union[list[str], str]] = None,
        card_template_path: Optional[Path] = None,
    ) -> ModelCard:
        """
        Creates Hugging Face model card

        Args:
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. The part after the last "/" selects which
                built-in model card template is used. If not set, the UK regional template is
                used.
            wandb_repo: Identifier of the repo on wandb.
            wandb_ids: Identifier(s) of the model on wandb. If not set, no wandb links are
                added to the card.
            card_template_path: Path to the HuggingFace model card template. Defaults to card in
                PVNet library if set to None.

        Returns:
            card: ModelCard - Hugging Face model card object
        """

        # Get appropriate model card. `repo_id` may be None or have no "organisation/" prefix,
        # so take the last path component instead of indexing position 1
        model_name = repo_id.split("/")[-1] if repo_id else ""
        if model_name == "windnet_india":
            model_card = "wind_india_model_card_template.md"
        elif model_name == "pvnet_india":
            model_card = "pv_india_model_card_template.md"
        else:
            model_card = "pv_uk_regional_model_card_template.md"

        # Creating and saving model card.
        card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
        if card_template_path is None:
            card_template_path = (
                f"{os.path.dirname(os.path.abspath(__file__))}/model_cards/{model_card}"
            )

        # Normalise `wandb_ids` to a list so None and a single string are both handled
        if wandb_ids is None:
            wandb_ids = []
        elif isinstance(wandb_ids, str):
            wandb_ids = [wandb_ids]

        wandb_links = ""
        for wandb_id in wandb_ids:
            link = f"https://wandb.ai/{wandb_repo}/runs/{wandb_id}"
            wandb_links += f" - [{link}]({link})\n"

        # Find package versions for OCF packages
        packages_to_display = ["pvnet", "ocf-data-sampler"]
        packages_and_versions = {
            package_name: pkg_resources.get_distribution(package_name).version
            for package_name in packages_to_display
        }

        package_versions_markdown = ""
        for package, version in packages_and_versions.items():
            package_versions_markdown += f" - {package}=={version}\n"

        return ModelCard.from_template(
            card_data,
            template_path=card_template_path,
            wandb_links=wandb_links,
            package_versions=package_versions_markdown
        )
470
-
471
-
472
class BaseModel(pl.LightningModule, PVNetModelHubMixin):
    """Abstract base class for PVNet submodels"""

    def __init__(
        self,
        history_minutes: int,
        forecast_minutes: int,
        optimizer: AbstractOptimizer,
        output_quantiles: Optional[list[float]] = None,
        target_key: str = "gsp",
        interval_minutes: int = 30,
        timestep_intervals_to_plot: Optional[list[Union[list[int], tuple[int, int]]]] = None,
        forecast_minutes_ignore: Optional[int] = 0,
        save_validation_results_csv: Optional[bool] = False,
    ):
        """Abstract base class for PVNet submodels.

        Args:
            history_minutes (int): Length of the GSP history period in minutes
            forecast_minutes (int): Length of the GSP forecast period in minutes
            optimizer (AbstractOptimizer): Optimizer
            output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
                None the output is a single value.
            target_key: The key of the target variable in the batch
            interval_minutes: The interval in minutes between each timestep in the data
            timestep_intervals_to_plot: Intervals, in timesteps, to plot during training. Each
                interval must be a list or tuple of length 2.
            forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses.
                For example if set to 60, the model doesn't predict the first 60 minutes. None is
                treated as 0.
            save_validation_results_csv: whether to save full csv outputs from validation results.

        Raises:
            ValueError: If any entry of `timestep_intervals_to_plot` is not a list or tuple of
                length 2.
        """
        super().__init__()

        self._optimizer = optimizer
        self._target_key = target_key
        if timestep_intervals_to_plot is not None:
            for interval in timestep_intervals_to_plot:
                # Raise explicitly: an `assert` is stripped under `python -O`
                if not (isinstance(interval, (list, tuple)) and len(interval) == 2):
                    raise ValueError(
                        f"timestep_intervals_to_plot must be a list of tuples or lists of "
                        f"length 2, but got {timestep_intervals_to_plot=}"
                    )
        self.time_step_intervals_to_plot = timestep_intervals_to_plot

        # Model must have lr to allow tuning
        # This setting is only used when lr is tuned with callback
        self.lr = None

        # The parameter is Optional, so make None mean "ignore nothing"
        if forecast_minutes_ignore is None:
            forecast_minutes_ignore = 0

        self.history_minutes = history_minutes
        self.forecast_minutes = forecast_minutes
        self.output_quantiles = output_quantiles
        self.interval_minutes = interval_minutes
        self.forecast_minutes_ignore = forecast_minutes_ignore

        # Number of timesteps, given the data interval
        self.history_len = history_minutes // interval_minutes
        self.forecast_len = (forecast_minutes - forecast_minutes_ignore) // interval_minutes
        self.forecast_len_ignore = forecast_minutes_ignore // interval_minutes

        self._accumulated_metrics = MetricAccumulator()
        self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key)
        self._accumulated_y_hat = PredAccumulator()
        self._horizon_maes = MetricAccumulator()

        # Store whether the model should use quantile regression or simply predict the mean
        self.use_quantile_regression = self.output_quantiles is not None

        # Store the number of output features that the model should predict for
        if self.use_quantile_regression:
            self.num_output_features = self.forecast_len * len(self.output_quantiles)
        else:
            self.num_output_features = self.forecast_len

        # save all validation results to array, so we can save these to weights n biases
        self.validation_epoch_results = []
        self.save_validation_results_csv = save_validation_results_csv

    def _adapt_batch(self, batch):
        """Slice batches into appropriate shapes for model.

        Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
        We make some specific assumptions about the original batch and the derived sliced batch:
          - We are only limiting the future projections. I.e. we are never shrinking the batch from
            the left hand side of the time axis, only slicing it from the right
          - We are only shrinking the spatial crop of the satellite and NWP data

        The `include_sat`, `include_nwp` and `include_sun` flags are not set by this base class.
        Subclasses which do not define them (e.g. the baseline models) are treated as not using
        those inputs.
        """
        # Create a copy of the batch to avoid modifying the original
        new_batch = {key: copy.deepcopy(value) for key, value in batch.items()}

        # NOTE(review): only the "gsp" key is sliced here, even when `target_key` is something
        # else - confirm whether other targets need the same treatment.
        if "gsp" in new_batch:
            # Slice off the end of the GSP data
            gsp_len = self.forecast_len + self.history_len + 1
            new_batch["gsp"] = new_batch["gsp"][:, :gsp_len]
            new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len]

        if getattr(self, "include_sat", False):
            # Slice off the end of the satellite data and spatially crop
            # Shape: batch_size, seq_length, channel, height, width
            new_batch["satellite_actual"] = center_crop(
                new_batch["satellite_actual"][:, : self.sat_sequence_len],
                output_size=self.sat_encoder.image_size_pixels,
            )

        if getattr(self, "include_nwp", False):
            # Slice off the end of the NWP data and spatially crop
            for nwp_source in self.nwp_encoders_dict:
                # shape: batch_size, seq_len, n_chans, height, width
                new_batch["nwp"][nwp_source]["nwp"] = center_crop(
                    new_batch["nwp"][nwp_source]["nwp"],
                    output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
                )[:, : self.nwp_encoders_dict[nwp_source].sequence_length]

        if getattr(self, "include_sun", False):
            sun_len = self.forecast_len + self.history_len + 1
            # Slice off end of solar coords
            for s in ["solar_azimuth", "solar_elevation"]:
                if s in new_batch:
                    new_batch[s] = new_batch[s][:, :sun_len]

        return new_batch

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        """Method to move custom batches to a given device"""
        return copy_batch_to_device(batch, device)

    def _quantiles_to_prediction(self, y_quantiles):
        """
        Convert network prediction into a point prediction.

        The point prediction is the 0.5 quantile (median), so `output_quantiles` must contain 0.5.

        Args:
            y_quantiles: Quantile prediction of network

        Returns:
            torch.Tensor: Point prediction

        Raises:
            ValueError: If 0.5 is not one of `output_quantiles`.
        """
        # y_quantiles Shape: batch_size, seq_length, num_quantiles
        idx = self.output_quantiles.index(0.5)
        y_median = y_quantiles[..., idx]
        return y_median

    def _calculate_quantile_loss(self, y_quantiles, y):
        """Calculate quantile loss.

        Note:
            Implementation copied from:
                https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting
                /metrics/quantile.html#QuantileLoss.loss

        Args:
            y_quantiles: Quantile prediction of network
            y: Target values

        Returns:
            Quantile loss
        """
        # calculate quantile (pinball) loss for each quantile, stacked on a new last axis
        losses = []
        for i, q in enumerate(self.output_quantiles):
            errors = y - y_quantiles[..., i]
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        losses = 2 * torch.cat(losses, dim=-1)

        return losses.mean()

    def _calculate_common_losses(self, y, y_hat):
        """Calculate losses common to train, and val"""

        losses = {}

        if self.use_quantile_regression:
            losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
            y_hat = self._quantiles_to_prediction(y_hat)

        # calculate mse, mae
        mse_loss = F.mse_loss(y_hat, y)
        mae_loss = F.l1_loss(y_hat, y)

        # TODO: Compute correlation coef using np.corrcoef(tensor with
        # shape (2, num_timesteps))[0, 1] on each example, and taking
        # the mean across the batch?
        losses.update(
            {
                "MSE": mse_loss,
                "MAE": mae_loss,
            }
        )

        return losses

    def _step_mae_and_mse(self, y, y_hat, dict_key_root):
        """Calculate the MSE and MAE at each forecast step"""
        losses = {}

        # Average over the batch axis only, leaving one value per forecast step
        mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
        mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)

        losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
        losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})

        return losses

    def _calculate_val_losses(self, y, y_hat):
        """Calculate additional validation losses"""

        losses = {}

        if self.use_quantile_regression:
            # Mask small values, which are dominated by night. The mask does not depend on the
            # quantile so is computed once.
            mask = y >= 0.01

            # Add fraction below each quantile for calibration
            for i, quantile in enumerate(self.output_quantiles):
                below_quant = y <= y_hat[..., i]
                losses[f"fraction_below_{quantile}_quantile"] = (below_quant[mask]).float().mean()

            # Take median value for remaining metric calculations
            y_hat = self._quantiles_to_prediction(y_hat)

        # Log the loss at each time horizon
        losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))

        # Log the persistence losses
        # NOTE(review): `y` as passed by `validation_step` holds only the forecast-period targets,
        # so `y[:, -1]` is the final target value rather than the last observed value - confirm
        # this is the intended persistence baseline.
        y_persist = y[:, -1].unsqueeze(1).expand(-1, self.forecast_len)
        losses["MAE_persistence/val"] = F.l1_loss(y_persist, y)
        losses["MSE_persistence/val"] = F.mse_loss(y_persist, y)

        # Log persistence loss at each time horizon
        losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence"))
        return losses

    def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
        """Internal function to accumulate training batches and log results.

        This is used when accumulating grad batches. Should make the variability in logged training
        step metrics independent of whether we accumulate N batches of size B or just use a larger
        batch size of N*B with no accumulation.
        """

        losses = {k: v.detach().cpu() for k, v in losses.items()}
        y_hat = y_hat.detach().cpu()

        self._accumulated_metrics.append(losses)
        self._accumulated_batches.append(batch)
        self._accumulated_y_hat.append(y_hat)

        if not self.trainer.fit_loop._should_accumulate():
            losses = self._accumulated_metrics.flush()
            batch = self._accumulated_batches.flush()
            y_hat = self._accumulated_y_hat.flush()

            self.log_dict(
                losses,
                on_step=True,
                on_epoch=True,
            )

            # Number of accumulated grad batches
            grad_batch_num = (batch_idx + 1) / self.trainer.accumulate_grad_batches

            # We only create the figure every 8 log steps
            # This was reduced as it was creating figures too often
            if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
                fig = plot_batch_forecasts(
                    batch,
                    y_hat,
                    batch_idx,
                    quantiles=self.output_quantiles,
                    key_to_plot=self._target_key,
                )
                fig.savefig("latest_logged_train_batch.png")
                plt.close(fig)

    def training_step(self, batch, batch_idx):
        """Run training step"""
        y_hat = self(batch)

        # Batch is adapted in the model forward method, but needs to be adapted here too
        batch = self._adapt_batch(batch)

        y = batch[self._target_key][:, -self.forecast_len :]

        losses = self._calculate_common_losses(y, y_hat)
        losses = {f"{k}/train": v for k, v in losses.items()}

        self._training_accumulate_log(batch, batch_idx, losses, y_hat)

        # Optimise the quantile loss if predicting quantiles, otherwise the MAE
        if self.use_quantile_regression:
            opt_target = losses["quantile_loss/train"]
        else:
            opt_target = losses["MAE/train"]
        return opt_target

    def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, plot_suffix):
        """Log forecast plot to wandb"""
        # NOTE(review): `timesteps_to_plot` is accepted but never passed to the plotting function,
        # so every interval plot shows the same data as the "all" plot - confirm intended use.
        fig = plot_batch_forecasts(
            batch,
            y_hat,
            quantiles=self.output_quantiles,
            key_to_plot=self._target_key,
        )

        plot_name = f"val_forecast_samples/batch_idx_{accum_batch_num}_{plot_suffix}"

        try:
            self.logger.experiment.log({plot_name: wandb.Image(fig)})
        except Exception as e:
            # Plot logging is best-effort and must not interrupt validation
            print(f"Failed to log {plot_name} to wandb")
            print(e)
        finally:
            # Always release the figure
            plt.close(fig)

    def _log_validation_results(self, batch, y_hat, accum_batch_num):
        """Append validation results to self.validation_epoch_results"""

        # get truth values, shape (b, forecast_len)
        y = batch[self._target_key][:, -self.forecast_len :]
        y = y.detach().cpu().numpy()
        batch_size = y.shape[0]

        # get prediction values, shape (b, forecast_len, quantiles?)
        y_hat = y_hat.detach().cpu().numpy()

        # get time_utc, shape (b, forecast_len)
        time_utc_key = f"{self._target_key}_time_utc"
        time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy()

        # get target id and change from (b,1) to (b,)
        id_key = f"{self._target_key}_id"
        target_id = batch[id_key].detach().cpu().numpy()
        target_id = target_id.squeeze()
        if target_id.ndim == 0:
            # With a batch size of 1, squeeze() removes every axis. Restore the batch axis so
            # the per-example indexing below still works.
            target_id = target_id.reshape(1)

        for i in range(batch_size):
            y_i = y[i]
            y_hat_i = y_hat[i]
            time_utc_i = time_utc[i]
            target_id_i = target_id[i]

            results_dict = {
                "y": y_i,
                "time_utc": time_utc_i,
            }
            if self.use_quantile_regression:
                # One column per predicted quantile
                results_dict.update(
                    {
                        f"y_quantile_{q}": y_hat_i[:, q_idx]
                        for q_idx, q in enumerate(self.output_quantiles)
                    }
                )
            else:
                results_dict["y_hat"] = y_hat_i

            results_df = pd.DataFrame(results_dict)
            results_df["id"] = target_id_i
            results_df["batch_idx"] = accum_batch_num
            results_df["example_idx"] = i

            self.validation_epoch_results.append(results_df)

    def validation_step(self, batch: dict, batch_idx):
        """Run validation step"""

        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

        y_hat = self(batch)
        # Batch is adapted in the model forward method, but needs to be adapted here too
        batch = self._adapt_batch(batch)

        y = batch[self._target_key][:, -self.forecast_len :]

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            self._log_validation_results(batch, y_hat, accum_batch_num)

        losses = self._calculate_common_losses(y, y_hat)
        losses.update(self._calculate_val_losses(y, y_hat))

        # Store these to make horizon accuracy plot
        self._horizon_maes.append(
            {i: losses[f"MAE_horizon/step_{i:03}"].cpu().numpy() for i in range(self.forecast_len)}
        )

        logged_losses = {f"{k}/val": v for k, v in losses.items()}

        self.log_dict(
            logged_losses,
            on_step=False,
            on_epoch=True,
        )

        # Make plots only if using wandb logger
        if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
            # Store these temporarily under self
            if not hasattr(self, "_val_y_hats"):
                self._val_y_hats = PredAccumulator()
                self._val_batches = BatchAccumulator(key_to_keep=self._target_key)

            self._val_y_hats.append(y_hat)
            self._val_batches.append(batch)

            # if batch has accumulated
            if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
                y_hat = self._val_y_hats.flush()
                batch = self._val_batches.flush()

                self._log_forecast_plot(
                    batch,
                    y_hat,
                    accum_batch_num,
                    timesteps_to_plot=None,
                    plot_suffix="all",
                )

                if self.time_step_intervals_to_plot is not None:
                    for interval in self.time_step_intervals_to_plot:
                        self._log_forecast_plot(
                            batch,
                            y_hat,
                            accum_batch_num,
                            timesteps_to_plot=interval,
                            plot_suffix=f"timestep_{interval}",
                        )

                del self._val_y_hats
                del self._val_batches

        return logged_losses

    def on_validation_epoch_end(self):
        """Run on epoch end"""

        try:
            # join together validation results, and save to wandb
            validation_results_df = pd.concat(self.validation_epoch_results)

            # The point prediction column depends on whether quantiles are being predicted.
            # These names match the columns written by `_log_validation_results`.
            if self.use_quantile_regression:
                prediction_column = "y_quantile_0.5"
            else:
                prediction_column = "y_hat"

            validation_results_df["error"] = (
                validation_results_df["y"] - validation_results_df[prediction_column]
            )

            if isinstance(self.logger, pl.loggers.WandbLogger):
                # log error distribution metrics
                wandb.log(
                    {
                        "2nd_percentile_median_forecast_error": validation_results_df[
                            "error"
                        ].quantile(0.02),
                        "5th_percentile_median_forecast_error": validation_results_df[
                            "error"
                        ].quantile(0.05),
                        "95th_percentile_median_forecast_error": validation_results_df[
                            "error"
                        ].quantile(0.95),
                        "98th_percentile_median_forecast_error": validation_results_df[
                            "error"
                        ].quantile(0.98),
                        "95th_percentile_median_forecast_absolute_error": abs(
                            validation_results_df["error"]
                        ).quantile(0.95),
                        "98th_percentile_median_forecast_absolute_error": abs(
                            validation_results_df["error"]
                        ).quantile(0.98),
                    }
                )
                # saving validation result csvs
                if self.save_validation_results_csv:
                    with tempfile.TemporaryDirectory() as tempdir:
                        filename = os.path.join(
                            tempdir, f"validation_results_{self.current_epoch}.csv"
                        )
                        validation_results_df.to_csv(filename, index=False)

                        # make and log wandb artifact
                        validation_artifact = wandb.Artifact(
                            f"validation_results_epoch_{self.current_epoch}", type="dataset"
                        )
                        validation_artifact.add_file(filename)
                        wandb.log_artifact(validation_artifact)

        except Exception as e:
            # Results logging is best-effort and must not interrupt training
            print("Failed to log validation results to wandb")
            print(e)

        # Reset the stores ready for the next validation epoch
        self.validation_epoch_results = []
        horizon_maes_dict = self._horizon_maes.flush()

        # Create the horizon accuracy curve
        if isinstance(self.logger, pl.loggers.WandbLogger):
            per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)]
            try:
                table = wandb.Table(data=per_step_losses, columns=["horizon_step", "MAE"])
                wandb.log(
                    {
                        "horizon_loss_curve": wandb.plot.line(
                            table, "horizon_step", "MAE", title="Horizon loss curve"
                        )
                    },
                )
            except Exception as e:
                print("Failed to log horizon_loss_curve to wandb")
                print(e)

    def configure_optimizers(self):
        """Configure the optimizers using learning rate found with LR finder if used"""
        if self.lr is not None:
            # Use learning rate found by learning rate finder callback
            self._optimizer.lr = self.lr
        return self._optimizer(self)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/baseline/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Baselines"""
 
 
pvnet/models/baseline/last_value.py DELETED
@@ -1,42 +0,0 @@
1
- """Persistence model"""
2
-
3
-
4
- import pvnet
5
- from pvnet.models.base_model import BaseModel
6
- from pvnet.optimizers import AbstractOptimizer
7
-
8
-
9
class Model(BaseModel):
    """Simple baseline model that takes the last gsp yield value and copies it forward."""

    name = "last_value"

    def __init__(
        self,
        forecast_minutes: int = 12,
        history_minutes: int = 6,
        optimizer: AbstractOptimizer = None,
    ):
        """Simple baseline model that takes the last gsp yield value and copies it forward.

        Args:
            history_minutes (int): Length of the GSP history period in minutes
            forecast_minutes (int): Length of the GSP forecast period in minutes
            optimizer (AbstractOptimizer): Optimizer. If None, a new `pvnet.optimizers.Adam()` is
                created for this instance.
        """
        # NOTE(review): with the base class default of 30 minute intervals, the default
        # forecast_minutes=12 and history_minutes=6 both floor-divide to 0 timesteps - confirm
        # these defaults are intended.

        # Create the default optimizer per instance. A default built in the signature would be
        # shared by every model, and the base class sets `.lr` on the optimizer object.
        if optimizer is None:
            optimizer = pvnet.optimizers.Adam()

        super().__init__(history_minutes, forecast_minutes, optimizer)
        self.save_hyperparameters()

    def forward(self, x: dict):
        """Run model forward on dict batch of data"""
        # The indexing below treats this as (batch_size, seq_length)
        # NOTE(review): an earlier comment described a trailing n_sites axis - confirm the shape
        gsp_yield = x["gsp"]

        # take the last non-forecast value, i.e. the value just before the forecast period,
        # counting back from the end of the sequence
        y_hat = gsp_yield[:, -self.forecast_len - 1]

        # repeat that value for each of the forecast steps
        out = y_hat.unsqueeze(1).repeat(1, self.forecast_len)
        return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/baseline/readme.md DELETED
@@ -1,5 +0,0 @@
1
- # Baseline Models
2
-
3
- - `last_value` - Forecasts the last historical PV yield of the sample for every forecast step
4
- - `single_value` - Learns a single value estimate and predicts this value for every input and every
5
- forecast step.
 
 
 
 
 
 
pvnet/models/baseline/single_value.py DELETED
@@ -1,36 +0,0 @@
1
- """Average value model"""
2
- import torch
3
- from torch import nn
4
-
5
- import pvnet
6
- from pvnet.models.base_model import BaseModel
7
- from pvnet.optimizers import AbstractOptimizer
8
-
9
-
10
class Model(BaseModel):
    """Simple baseline model that predicts always the same value."""

    name = "single_value"

    def __init__(
        self,
        forecast_minutes: int = 120,
        history_minutes: int = 60,
        optimizer: AbstractOptimizer = None,
    ):
        """Simple baseline model that predicts always the same value.

        Args:
            history_minutes (int): Length of the GSP history period in minutes
            forecast_minutes (int): Length of the GSP forecast period in minutes
            optimizer (AbstractOptimizer): Optimizer. If None, a new `pvnet.optimizers.Adam()` is
                created for this instance.
        """
        # Create the default optimizer per instance. A default built in the signature would be
        # shared by every model, and the base class sets `.lr` on the optimizer object.
        if optimizer is None:
            optimizer = pvnet.optimizers.Adam()

        super().__init__(history_minutes, forecast_minutes, optimizer)
        # The single learnable value which is predicted everywhere
        self._value = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.save_hyperparameters()

    def forward(self, x: dict):
        """Run model forward on dict batch of data"""
        # Returns a single value at all steps. The slice of the input is only used to get the
        # output shape, dtype and device.
        y_hat = torch.zeros_like(x["gsp"][:, : self.forecast_len]) + self._value
        return y_hat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/ensemble.py DELETED
@@ -1,74 +0,0 @@
1
- """Model which uses multiple prediction heads"""
2
- from typing import Optional
3
-
4
- import torch
5
- from torch import nn
6
-
7
- from pvnet.models.base_model import BaseModel
8
-
9
-
10
class Ensemble(BaseModel):
    """Ensemble of PVNet models"""

    def __init__(
        self,
        model_list: list[BaseModel],
        weights: Optional[list[float]] = None,
    ):
        """Ensemble of PVNet models

        Args:
            model_list: A list of PVNet models to ensemble. Must not be empty.
            weights: A list of weighting to apply to each model. If None, the models are weighted
                equally. The weights are normalised to sum to 1.

        Raises:
            ValueError: If `model_list` is empty, the models have differing properties, the
                number of weights does not match the number of models, or the weights sum to zero.
        """
        if len(model_list) == 0:
            raise ValueError("model_list must contain at least one model")

        # Get some model properties from each model
        model_properties = {
            "output_quantiles": [model.output_quantiles for model in model_list],
            "history_minutes": [model.history_minutes for model in model_list],
            "forecast_minutes": [model.forecast_minutes for model in model_list],
            "target_key": [model._target_key for model in model_list],
            "interval_minutes": [model.interval_minutes for model in model_list],
        }

        # Surface check all the models are compatible, i.e. these properties are all the same
        for property_name, values in model_properties.items():
            if any(value != values[0] for value in values):
                raise ValueError(
                    f"All models in the ensemble must share the same {property_name}, "
                    f"but got {values}"
                )

        super().__init__(
            history_minutes=model_properties["history_minutes"][0],
            forecast_minutes=model_properties["forecast_minutes"][0],
            optimizer=None,
            output_quantiles=model_properties["output_quantiles"][0],
            target_key=model_properties["target_key"][0],
            interval_minutes=model_properties["interval_minutes"][0],
        )

        self.model_list = nn.ModuleList(model_list)

        if weights is None:
            weights = torch.ones(len(model_list)) / len(model_list)
        else:
            if len(weights) != len(model_list):
                raise ValueError(
                    f"Expected one weight per model ({len(model_list)}), "
                    f"but got {len(weights)} weights"
                )
            weights_total = sum(weights)
            if weights_total == 0:
                # Normalising by zero would silently give NaN or inf predictions
                raise ValueError("weights must not sum to zero")
            weights = torch.Tensor(weights) / weights_total

        # Stored as a non-trainable parameter so it is part of the module state
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, batch):
        """Run the model forward"""
        # Weighted sum of the predictions of each model
        y_hat = 0
        for weight, model in zip(self.weights, self.model_list):
            y_hat = model(batch) * weight + y_hat
        return y_hat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/model_cards/pv_india_model_card_template.md DELETED
@@ -1,56 +0,0 @@
1
- ---
2
- {{ card_data }}
3
- ---
4
-
5
-
6
-
7
-
8
-
9
-
10
- # PVNet India
11
-
12
- ## Model Description
13
-
14
- <!-- Provide a longer summary of what this model is/does. -->
15
- This model class uses numerical weather predictions from providers such as ECMWF to forecast the PV power in North West India over the next 48 hours. More information can be found in the model repo [1] and experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india).
16
-
17
-
18
- - **Developed by:** openclimatefix
19
- - **Model type:** Fusion model
20
- - **Language(s) (NLP):** en
21
- - **License:** mit
22
-
23
-
24
- # Training Details
25
-
26
- ## Data
27
-
28
- <!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
29
-
30
- The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india)
31
-
32
-
33
- ### Preprocessing
34
-
35
- Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/site` Dataset [2].
36
-
37
-
38
- ## Results
39
-
40
- The training logs for the current model can be found here:
41
- {{ wandb_links }}
42
-
43
-
44
- ### Hardware
45
-
46
- Trained on a single NVIDIA Tesla T4
47
-
48
- ### Software
49
-
50
- This model was trained using the following Open Climate Fix packages:
51
-
52
- - [1] https://github.com/openclimatefix/PVNet
53
- - [2] https://github.com/openclimatefix/ocf-data-sampler
54
-
55
- The versions of these packages can be found below:
56
- {{ package_versions }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/model_cards/pv_uk_regional_model_card_template.md DELETED
@@ -1,59 +0,0 @@
1
- ---
2
- {{ card_data }}
3
- ---
4
-
5
-
6
-
7
-
8
-
9
-
10
- # PVNet2
11
-
12
- ## Model Description
13
-
14
- <!-- Provide a longer summary of what this model is/does. -->
15
- This model class uses satellite data, numerical weather predictions, and recent Grid Supply Point (GSP) PV power output to forecast the near-term (~8 hours) PV power output at all GSPs. More information can be found in the model repo [1] and experimental notes in [this google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing).
16
-
17
- - **Developed by:** openclimatefix
18
- - **Model type:** Fusion model
19
- - **Language(s) (NLP):** en
20
- - **License:** mit
21
-
22
-
23
- # Training Details
24
-
25
- ## Data
26
-
27
- <!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
28
-
29
- The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes in [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details.
30
-
31
-
32
- ### Preprocessing
33
-
34
- Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/pvnet_uk` Dataset [2].
35
-
36
-
37
- ## Results
38
-
39
- The training logs for the current model can be found here:
40
- {{ wandb_links }}
41
-
42
- The training logs for all model runs of PVNet2 can be found [here](https://wandb.ai/openclimatefix/pvnet2.1).
43
-
44
- Some experimental notes can be found in [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing).
45
-
46
-
47
- ### Hardware
48
-
49
- Trained on a single NVIDIA Tesla T4
50
-
51
- ### Software
52
-
53
- This model was trained using the following Open Climate Fix packages:
54
-
55
- - [1] https://github.com/openclimatefix/PVNet
56
- - [2] https://github.com/openclimatefix/ocf-data-sampler
57
-
58
- The versions of these packages can be found below:
59
- {{ package_versions }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/model_cards/wind_india_model_card_template.md DELETED
@@ -1,56 +0,0 @@
1
- ---
2
- {{ card_data }}
3
- ---
4
-
5
-
6
-
7
-
8
-
9
-
10
- # WindNet
11
-
12
- ## Model Description
13
-
14
- <!-- Provide a longer summary of what this model is/does. -->
15
- This model class uses numerical weather predictions from providers such as ECMWF to forecast the wind power in North West India over the next 48 hours at 15 minute granularity. More information can be found in the model repo [1] and experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india).
16
-
17
-
18
- - **Developed by:** openclimatefix
19
- - **Model type:** Fusion model
20
- - **Language(s) (NLP):** en
21
- - **License:** mit
22
-
23
-
24
- # Training Details
25
-
26
- ## Data
27
-
28
- <!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
29
-
30
- The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india)
31
-
32
-
33
- ### Preprocessing
34
-
35
- Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/site` Dataset [2].
36
-
37
-
38
- ## Results
39
-
40
- The training logs for the current model can be found here:
41
- {{ wandb_links }}
42
-
43
-
44
- ### Hardware
45
-
46
- Trained on a single NVIDIA Tesla T4
47
-
48
- ### Software
49
-
50
- This model was trained using the following Open Climate Fix packages:
51
-
52
- - [1] https://github.com/openclimatefix/PVNet
53
- - [2] https://github.com/openclimatefix/ocf-data-sampler
54
-
55
- The versions of these packages can be found below:
56
- {{ package_versions }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Multimodal Models"""
 
 
pvnet/models/multimodal/basic_blocks.py DELETED
@@ -1,104 +0,0 @@
1
- """Basic layers for composite models"""
2
-
3
- import warnings
4
-
5
- import torch
6
- from torch import _VF, nn
7
-
8
-
9
- class ImageEmbedding(nn.Module):
10
- """An embedding layer which concatenates an ID embedding as a new channel onto 3D inputs."""
11
-
12
- def __init__(self, num_embeddings, sequence_length, image_size_pixels, **kwargs):
13
- """An embedding layer which concatenates an ID embedding as a new channel onto 3D inputs.
14
-
15
- The embedding is a single 2D image and is appended at each step in the 1st dimension
16
- (assumed to be time).
17
-
18
- Args:
19
- num_embeddings: Size of the dictionary of embeddings
20
- sequence_length: The time sequence length of the data.
21
- image_size_pixels: The spatial size of the image. Assumed square.
22
- **kwargs: See `torch.nn.Embedding` for more possible arguments.
23
- """
24
- super().__init__()
25
- self.image_size_pixels = image_size_pixels
26
- self.sequence_length = sequence_length
27
- self._embed = nn.Embedding(
28
- num_embeddings=num_embeddings,
29
- embedding_dim=image_size_pixels * image_size_pixels,
30
- **kwargs,
31
- )
32
-
33
- def forward(self, x, id):
34
- """Append ID embedding to image"""
35
- emb = self._embed(id)
36
- emb = emb.reshape((-1, 1, 1, self.image_size_pixels, self.image_size_pixels))
37
- emb = emb.repeat(1, 1, self.sequence_length, 1, 1)
38
- x = torch.cat((x, emb), dim=1)
39
- return x
40
-
41
-
42
- class CompleteDropoutNd(nn.Module):
43
- """A layer used to completely drop out all elements of a N-dimensional sample.
44
-
45
- Each sample will be zeroed out independently on every forward call with probability `p` using
46
- samples from a Bernoulli distribution.
47
-
48
- """
49
-
50
- __constants__ = ["p", "inplace", "n_dim"]
51
- p: float
52
- inplace: bool
53
- n_dim: int
54
-
55
- def __init__(self, n_dim, p=0.5, inplace=False):
56
- """A layer used to completely drop out all elements of a N-dimensional sample.
57
-
58
- Args:
59
- n_dim: Number of dimensions of each sample not including channels. E.g. a sample with
60
- shape (channel, time, height, width) would use `n_dim=3`.
61
- p: probability of a sample to be zeroed. Default: 0.5
62
- training: apply dropout if is `True`. Default: `True`
63
- inplace: If set to `True`, will do this operation in-place. Default: `False`
64
- """
65
- super().__init__()
66
- if p < 0 or p > 1:
67
- raise ValueError(
68
- "dropout probability has to be between 0 and 1, " "but got {}".format(p)
69
- )
70
- self.p = p
71
- self.inplace = inplace
72
- self.n_dim = n_dim
73
-
74
- def forward(self, input: torch.Tensor) -> torch.Tensor:
75
- """Run dropout"""
76
- p = self.p
77
- inp_dim = input.dim()
78
-
79
- if inp_dim not in (self.n_dim + 1, self.n_dim + 2):
80
- warn_msg = (
81
- f"CompleteDropoutNd: Received a {inp_dim}-D input. Expected either a single sample"
82
- f" with {self.n_dim+1} dimensions, or a batch of samples with {self.n_dim+2}"
83
- " dimensions."
84
- )
85
- warnings.warn(warn_msg)
86
-
87
- is_batched = inp_dim == self.n_dim + 2
88
- if not is_batched:
89
- input = input.unsqueeze_(0) if self.inplace else input.unsqueeze(0)
90
-
91
- input = input.unsqueeze_(1) if self.inplace else input.unsqueeze(1)
92
-
93
- result = (
94
- _VF.feature_dropout_(input, p, self.training)
95
- if self.inplace
96
- else _VF.feature_dropout(input, p, self.training)
97
- )
98
-
99
- result = result.squeeze_(1) if self.inplace else result.squeeze(1)
100
-
101
- if not is_batched:
102
- result = result.squeeze_(0) if self.inplace else result.squeeze(0)
103
-
104
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/encoders/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Submodels to encode satellite and NWP inputs"""
 
 
pvnet/models/multimodal/encoders/basic_blocks.py DELETED
@@ -1,217 +0,0 @@
1
- """Basic blocks for image sequence encoders"""
2
- from abc import ABCMeta, abstractmethod
3
-
4
- import torch
5
- from torch import nn
6
-
7
-
8
- class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
9
- """Abstract class for NWP/satellite encoder.
10
-
11
- The encoder will take an input of shape (batch_size, sequence_length, channels, height, width)
12
- and return an output of shape (batch_size, out_features).
13
- """
14
-
15
- def __init__(
16
- self,
17
- sequence_length: int,
18
- image_size_pixels: int,
19
- in_channels: int,
20
- out_features: int,
21
- ):
22
- """Abstract class for NWP/satellite encoder.
23
-
24
- Args:
25
- sequence_length: The time sequence length of the data.
26
- image_size_pixels: The spatial size of the image. Assumed square.
27
- in_channels: Number of input channels.
28
- out_features: Number of output features.
29
- """
30
- super().__init__()
31
- self.out_features = out_features
32
- self.image_size_pixels = image_size_pixels
33
- self.sequence_length = sequence_length
34
-
35
- @abstractmethod
36
- def forward(self):
37
- """Run model forward"""
38
- pass
39
-
40
-
41
- class ResidualConv3dBlock(nn.Module):
42
- """Residual block of 3D convolutional layers based on ResNet architecture.
43
-
44
- Internally, this network uses ELU activations throughout the residual blocks.
45
- """
46
-
47
- def __init__(
48
- self,
49
- in_channels,
50
- n_layers: int = 2,
51
- dropout_frac: float = 0.0,
52
- ):
53
- """Residual block of 3D convolutional layers based on ResNet architecture.
54
-
55
- Args:
56
- in_channels: Number of input channels.
57
- n_layers: Number of layers in residual pathway.
58
- dropout_frac: Probability of an element to be zeroed.
59
- """
60
- super().__init__()
61
-
62
- layers = []
63
- for i in range(n_layers):
64
- layers += [
65
- nn.ELU(),
66
- nn.Conv3d(
67
- in_channels=in_channels,
68
- out_channels=in_channels,
69
- kernel_size=(3, 3, 3),
70
- padding=(1, 1, 1),
71
- ),
72
- nn.Dropout3d(p=dropout_frac),
73
- ]
74
-
75
- self.model = nn.Sequential(*layers)
76
-
77
- def forward(self, x):
78
- """Run residual connection"""
79
- return self.model(x) + x
80
-
81
-
82
- class ResidualConv3dBlock2(nn.Module):
83
- """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
84
-
85
- This was the best performing residual block tested in the study. This implementation differs
86
- from that block just by using LeakyReLU activation to avoid dead neurons, and by including
87
- optional dropout in the residual branch. This is also a 3D convolutional residual block
88
- rather than a 2D convolutional block.
89
-
90
- Sources:
91
- [1] https://arxiv.org/pdf/1603.05027.pdf
92
- """
93
-
94
- def __init__(
95
- self,
96
- in_channels: int,
97
- n_layers: int = 2,
98
- dropout_frac: float = 0.0,
99
- batch_norm: bool = True,
100
- ):
101
- """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
102
-
103
- Sources:
104
- [1] https://arxiv.org/pdf/1603.05027.pdf
105
-
106
- Args:
107
- in_channels: Number of input channels.
108
- n_layers: Number of layers in residual pathway.
109
- dropout_frac: Probability of an element to be zeroed.
110
- batch_norm: Whether to use batchnorm
111
- """
112
- super().__init__()
113
-
114
- layers = []
115
- for i in range(n_layers):
116
- if batch_norm:
117
- layers.append(nn.BatchNorm3d(in_channels))
118
- layers.extend(
119
- [
120
- nn.Dropout3d(p=dropout_frac),
121
- nn.LeakyReLU(),
122
- nn.Conv3d(
123
- in_channels=in_channels,
124
- out_channels=in_channels,
125
- kernel_size=(3, 3, 3),
126
- padding=(1, 1, 1),
127
- ),
128
- ]
129
- )
130
-
131
- self.model = nn.Sequential(*layers)
132
-
133
- def forward(self, x):
134
- """Run model forward"""
135
- return self.model(x) + x
136
-
137
-
138
- class ImageSequenceEncoder(nn.Module):
139
- """Simple network which independently encodes each image in a sequence into 1D features"""
140
-
141
- def __init__(
142
- self,
143
- image_size_pixels: int,
144
- in_channels: int,
145
- number_of_conv2d_layers: int = 4,
146
- conv2d_channels: int = 32,
147
- fc_features: int = 128,
148
- ):
149
- """Simple network which independently encodes each image in a sequence into 1D features.
150
-
151
- For input image with shape [N, C, L, H, W] the output is of shape [N, L, fc_features] where
152
- N is number of samples in batch, C is the number of input channels, L is the length of the
153
- sequence, and H and W are the height and width.
154
-
155
- Args:
156
- image_size_pixels: The spatial size of the image. Assumed square.
157
- in_channels: Number of input channels.
158
- number_of_conv2d_layers: Number of convolution 2D layers that are used.
159
- conv2d_channels: Number of channels used in each conv2d layer.
160
- fc_features: Number of output nodes for each image in each sequence.
161
- """
162
- super().__init__()
163
-
164
- # Check that the output shape of the convolutional layers will be at least 1x1
165
- cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv2d_layers
166
- if not (cnn_spatial_output_size >= 1):
167
- raise ValueError(
168
- f"cannot use this many conv2d layers ({number_of_conv2d_layers}) with this input "
169
- f"spatial size ({image_size_pixels})"
170
- )
171
-
172
- conv_layers = []
173
-
174
- conv_layers += [
175
- nn.Conv2d(
176
- in_channels=in_channels,
177
- out_channels=conv2d_channels,
178
- kernel_size=3,
179
- padding=0,
180
- ),
181
- nn.ELU(),
182
- ]
183
- for i in range(0, number_of_conv2d_layers - 1):
184
- conv_layers += [
185
- nn.Conv2d(
186
- in_channels=conv2d_channels,
187
- out_channels=conv2d_channels,
188
- kernel_size=3,
189
- padding=0,
190
- ),
191
- nn.ELU(),
192
- ]
193
-
194
- self.conv_layers = nn.Sequential(*conv_layers)
195
-
196
- self.final_block = nn.Sequential(
197
- nn.Linear(
198
- in_features=(cnn_spatial_output_size**2) * conv2d_channels,
199
- out_features=fc_features,
200
- ),
201
- nn.ELU(),
202
- )
203
-
204
- def forward(self, x):
205
- """Run model forward"""
206
- batch_size, channel, seq_len, height, width = x.shape
207
-
208
- x = torch.swapaxes(x, 1, 2)
209
- x = x.reshape(batch_size * seq_len, channel, height, width)
210
-
211
- out = self.conv_layers(x)
212
- out = out.reshape(batch_size * seq_len, -1)
213
-
214
- out = self.final_block(out)
215
- out = out.reshape(batch_size, seq_len, -1)
216
-
217
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/encoders/encoders2d.py DELETED
@@ -1,413 +0,0 @@
1
- """Encoder modules for the satellite/NWP data.
2
-
3
- These networks naively stack the sequences into extra channels before putting through their
4
- architectures.
5
- """
6
-
7
- from functools import partial
8
- from typing import Any, Callable, List, Optional, Sequence, Type, Union
9
-
10
- import torch
11
- from torch import Tensor, nn
12
- from torchvision.models.convnext import CNBlock, CNBlockConfig, LayerNorm2d
13
- from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
14
- from torchvision.ops.misc import Conv2dNormActivation
15
- from torchvision.utils import _log_api_usage_once
16
-
17
- from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
18
-
19
-
20
- class NaiveEfficientNet(AbstractNWPSatelliteEncoder):
21
- """An implementation of EfficientNet from `efficientnet_pytorch`.
22
-
23
- This model is quite naive, and just stacks the sequence into channels.
24
- """
25
-
26
- def __init__(
27
- self,
28
- sequence_length: int,
29
- image_size_pixels: int,
30
- in_channels: int,
31
- out_features: int,
32
- model_name: str = "efficientnet-b0",
33
- ):
34
- """An implementation of EfficientNet from `efficientnet_pytorch`.
35
-
36
- This model is quite naive, and just stacks the sequence into channels.
37
-
38
- Args:
39
- sequence_length: The time sequence length of the data.
40
- image_size_pixels: The spatial size of the image. Assumed square.
41
- in_channels: Number of input channels.
42
- out_features: Number of output features.
43
- model_name: Name of EfficientNet model to construct.
44
-
45
- Notes:
46
- The `efficientnet_pytorch` package must be installed to use `NaiveEfficientNet`.
47
- See https://github.com/lukemelas/EfficientNet-PyTorch for install instructions.
48
- """
49
-
50
- from efficientnet_pytorch import EfficientNet
51
-
52
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
53
-
54
- self.model = EfficientNet.from_name(
55
- model_name,
56
- in_channels=in_channels * sequence_length,
57
- image_size=image_size_pixels,
58
- num_classes=out_features,
59
- )
60
-
61
- def forward(self, x):
62
- """Run model forward"""
63
- bs, s, c, h, w = x.shape
64
- x = x.reshape((bs, s * c, h, w))
65
- return self.model(x)
66
-
67
-
68
- class NaiveResNet(nn.Module):
69
- """A ResNet model modified from one in torchvision [1].
70
-
71
- Modified to allow a different number of input channels. This model is quite naive, and just stacks
72
- the sequence into channels.
73
-
74
- Example use:
75
- ```
76
- resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
77
- resnet50 = ResNet(Bottleneck, [3, 4, 6, 3])
78
- ```
79
-
80
- Sources:
81
- [1] https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
82
- [2] https://pytorch.org/hub/pytorch_vision_resnet
83
- """
84
-
85
- def __init__(
86
- self,
87
- sequence_length: int,
88
- image_size_pixels: int,
89
- in_channels: int,
90
- out_features: int,
91
- layers: List[int] = [2, 2, 2, 2],
92
- block: str = "bottleneck",
93
- zero_init_residual: bool = False,
94
- groups: int = 1,
95
- width_per_group: int = 64,
96
- replace_stride_with_dilation: Optional[List[bool]] = None,
97
- norm_layer: Optional[Callable[..., nn.Module]] = None,
98
- ):
99
- """A ResNet model modified from one in torchvision [1].
100
-
101
- Args:
102
- sequence_length: The time sequence length of the data.
103
- image_size_pixels: The spatial size of the image. Assumed square.
104
- in_channels: Number of input channels.
105
- out_features: Number of output features.
106
- layers: See [1] and [2].
107
- block: See [1] and [2].
108
- zero_init_residual: See [1] and [2].
109
- groups: See [1] and [2].
110
- width_per_group: See [1] and [2].
111
- replace_stride_with_dilation: See [1] and [2].
112
- norm_layer: See [1] and [2].
113
-
114
- Sources:
115
- [1] https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
116
- [2] https://pytorch.org/hub/pytorch_vision_resnet
117
- """
118
- super().__init__()
119
- _log_api_usage_once(self)
120
- if norm_layer is None:
121
- norm_layer = nn.BatchNorm2d
122
- self._norm_layer = norm_layer
123
-
124
- # Account for stacking sequences into more channels
125
- in_channels = in_channels * sequence_length
126
-
127
- block = {
128
- "basic": BasicBlock,
129
- "bottleneck": Bottleneck,
130
- }[block]
131
-
132
- self.inplanes = 64
133
- self.dilation = 1
134
- if replace_stride_with_dilation is None:
135
- # each element in the tuple indicates if we should replace
136
- # the 2x2 stride with a dilated convolution instead
137
- replace_stride_with_dilation = [False, False, False]
138
- if len(replace_stride_with_dilation) != 3:
139
- raise ValueError(
140
- "replace_stride_with_dilation should be None "
141
- f"or a 3-element tuple, got {replace_stride_with_dilation}"
142
- )
143
- self.groups = groups
144
- self.base_width = width_per_group
145
- self.conv1 = nn.Conv2d(
146
- in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
147
- )
148
- self.bn1 = norm_layer(self.inplanes)
149
- self.relu = nn.ReLU(inplace=True)
150
- # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
151
- self.layer1 = self._make_layer(block, 64, layers[0])
152
- self.layer2 = self._make_layer(
153
- block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
154
- )
155
- self.layer3 = self._make_layer(
156
- block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
157
- )
158
- self.layer4 = self._make_layer(
159
- block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
160
- )
161
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
162
- self.fc = nn.Linear(512 * block.expansion, out_features)
163
- self.final_act = nn.LeakyReLU()
164
-
165
- for m in self.modules():
166
- if isinstance(m, nn.Conv2d):
167
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
168
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
169
- nn.init.constant_(m.weight, 1)
170
- nn.init.constant_(m.bias, 0)
171
-
172
- # Zero-initialize the last BN in each residual branch,
173
- # so that the residual branch starts with zeros, and each residual block behaves like an
174
- # identity. This improves the model by 0.2~0.3% according to
175
- # https://arxiv.org/abs/1706.02677
176
- if zero_init_residual:
177
- for m in self.modules():
178
- if isinstance(m, Bottleneck) and m.bn3.weight is not None:
179
- nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
180
- elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
181
- nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
182
-
183
- def _make_layer(
184
- self,
185
- block: Type[Union[BasicBlock, Bottleneck]],
186
- planes: int,
187
- blocks: int,
188
- stride: int = 1,
189
- dilate: bool = False,
190
- ) -> nn.Sequential:
191
- norm_layer = self._norm_layer
192
- downsample = None
193
- previous_dilation = self.dilation
194
- if dilate:
195
- self.dilation *= stride
196
- stride = 1
197
- if stride != 1 or self.inplanes != planes * block.expansion:
198
- downsample = nn.Sequential(
199
- conv1x1(self.inplanes, planes * block.expansion, stride),
200
- norm_layer(planes * block.expansion),
201
- )
202
-
203
- layers = []
204
- layers.append(
205
- block(
206
- self.inplanes,
207
- planes,
208
- stride,
209
- downsample,
210
- self.groups,
211
- self.base_width,
212
- previous_dilation,
213
- norm_layer,
214
- )
215
- )
216
- self.inplanes = planes * block.expansion
217
- for _ in range(1, blocks):
218
- layers.append(
219
- block(
220
- self.inplanes,
221
- planes,
222
- groups=self.groups,
223
- base_width=self.base_width,
224
- dilation=self.dilation,
225
- norm_layer=norm_layer,
226
- )
227
- )
228
-
229
- return nn.Sequential(*layers)
230
-
231
- def _forward_impl(self, x: Tensor) -> Tensor:
232
- # See note [TorchScript super()]
233
- x = self.conv1(x)
234
- x = self.bn1(x)
235
- x = self.relu(x)
236
- # x = self.maxpool(x)
237
-
238
- x = self.layer1(x)
239
- x = self.layer2(x)
240
- x = self.layer3(x)
241
- x = self.layer4(x)
242
-
243
- x = self.avgpool(x)
244
- x = torch.flatten(x, 1)
245
- x = self.fc(x)
246
- x = self.final_act(x)
247
-
248
- return x
249
-
250
- def forward(self, x: Tensor) -> Tensor:
251
- """Run model forward"""
252
- bs, s, c, h, w = x.shape
253
- x = x.reshape((bs, s * c, h, w))
254
- return self._forward_impl(x)
255
-
256
-
257
- class NaiveConvNeXt(nn.Module):
258
- """A NaiveConvNeXt model [1] modified from one in torchvision [2].
259
-
260
- Modified to allow a different number of input channels, and smaller spatial inputs. This model is
261
- quite naive, and just stacks the sequence into channels.
262
-
263
- Example usage:
264
- ```
265
- block_setting = [
266
- CNBlockConfig(96, 192, 3),
267
- CNBlockConfig(192, 384, 3),
268
- CNBlockConfig(384, 768, 9),
269
- CNBlockConfig(768, None, 3),
270
- ]
271
-
272
- sequence_len = 12
273
- channels = 2
274
- pixels=24
275
-
276
- convnext_tiny = NaiveConvNeXt(
277
- sequence_length=12,
278
- image_size_pixels=24,
279
- in_channels=2,
280
- out_features=128,
281
- block_setting=block_setting,
282
- stochastic_depth_prob=0.1,
283
- )
284
- ```
285
-
286
- Sources:
287
- [1] https://arxiv.org/abs/2201.03545
288
- [2] https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
289
- [3] https://pytorch.org/vision/main/models/convnext.html
290
-
291
- """
292
-
293
- def __init__(
294
- self,
295
- sequence_length: int,
296
- image_size_pixels: int,
297
- in_channels: int,
298
- out_features: int,
299
- block_setting: List[CNBlockConfig],
300
- stochastic_depth_prob: float = 0.0,
301
- layer_scale: float = 1e-6,
302
- block: Optional[Callable[..., nn.Module]] = None,
303
- norm_layer: Optional[Callable[..., nn.Module]] = None,
304
- **kwargs: Any,
305
- ) -> None:
306
- """A ConvNeXt model [1] modified from one in torchvision [2].
307
-
308
- Args:
309
- sequence_length: The time sequence length of the data.
310
- image_size_pixels: The spatial size of the image. Assumed square.
311
- in_channels: Number of input channels.
312
- out_features: Number of output features.
313
- block_setting: See [2] and [3].
314
- stochastic_depth_prob: See [2] and [3].
315
- layer_scale: See [2] and [3].
316
- block: See [2] and [3].
317
- norm_layer: See [2] and [3].
318
- **kwargs: See [2] and [3].
319
-
320
- Sources:
321
- [1] https://arxiv.org/abs/2201.03545
322
- [2] https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
323
- [3] https://pytorch.org/vision/main/models/convnext.html
324
- """
325
- super().__init__()
326
- _log_api_usage_once(self)
327
-
328
- if not block_setting:
329
- raise ValueError("The block_setting should not be empty")
330
- elif not (
331
- isinstance(block_setting, Sequence)
332
- and all([isinstance(s, CNBlockConfig) for s in block_setting])
333
- ):
334
- raise TypeError("The block_setting should be List[CNBlockConfig]")
335
-
336
- if block is None:
337
- block = CNBlock
338
-
339
- if norm_layer is None:
340
- norm_layer = partial(LayerNorm2d, eps=1e-6)
341
-
342
- layers: List[nn.Module] = []
343
-
344
- # Account for stacking sequences into more channels
345
- in_channels = in_channels * sequence_length
346
-
347
- # Stem
348
- firstconv_output_channels = block_setting[0].input_channels
349
- layers.append(
350
- Conv2dNormActivation(
351
- in_channels,
352
- firstconv_output_channels,
353
- kernel_size=2,
354
- stride=2,
355
- padding=0,
356
- norm_layer=norm_layer,
357
- activation_layer=None,
358
- bias=True,
359
- )
360
- )
361
-
362
- total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
363
- stage_block_id = 0
364
- for cnf in block_setting:
365
- # Bottlenecks
366
- stage: List[nn.Module] = []
367
- for _ in range(cnf.num_layers):
368
- # adjust stochastic depth probability based on the depth of the stage block
369
- sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
370
- stage.append(block(cnf.input_channels, layer_scale, sd_prob))
371
- stage_block_id += 1
372
- layers.append(nn.Sequential(*stage))
373
- if cnf.out_channels is not None:
374
- # Downsampling
375
- layers.append(
376
- nn.Sequential(
377
- norm_layer(cnf.input_channels),
378
- nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
379
- )
380
- )
381
-
382
- self.features = nn.Sequential(*layers)
383
- self.avgpool = nn.AdaptiveAvgPool2d(1)
384
-
385
- lastblock = block_setting[-1]
386
- lastconv_output_channels = (
387
- lastblock.out_channels
388
- if lastblock.out_channels is not None
389
- else lastblock.input_channels
390
- )
391
- self.classifier = nn.Sequential(
392
- norm_layer(lastconv_output_channels),
393
- nn.Flatten(1),
394
- nn.Linear(lastconv_output_channels, out_features),
395
- )
396
-
397
- for m in self.modules():
398
- if isinstance(m, (nn.Conv2d, nn.Linear)):
399
- nn.init.trunc_normal_(m.weight, std=0.02)
400
- if m.bias is not None:
401
- nn.init.zeros_(m.bias)
402
-
403
- def _forward_impl(self, x: Tensor) -> Tensor:
404
- x = self.features(x)
405
- x = self.avgpool(x)
406
- x = self.classifier(x)
407
- return x
408
-
409
- def forward(self, x: Tensor) -> Tensor:
410
- """Run model forward"""
411
- bs, s, c, h, w = x.shape
412
- x = x.reshape((bs, s * c, h, w))
413
- return self._forward_impl(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/encoders/encoders3d.py DELETED
@@ -1,402 +0,0 @@
1
- """Encoder modules for the satellite/NWP data based on 3D convolutions.
2
- """
3
- from typing import List, Union
4
-
5
- import torch
6
- from torch import nn
7
- from torchvision.transforms import CenterCrop
8
-
9
- from pvnet.models.multimodal.encoders.basic_blocks import (
10
- AbstractNWPSatelliteEncoder,
11
- ResidualConv3dBlock,
12
- ResidualConv3dBlock2,
13
- )
14
-
15
-
16
- class DefaultPVNet(AbstractNWPSatelliteEncoder):
17
- """This is the original encoding module used in PVNet, with a few minor tweaks."""
18
-
19
- def __init__(
20
- self,
21
- sequence_length: int,
22
- image_size_pixels: int,
23
- in_channels: int,
24
- out_features: int,
25
- number_of_conv3d_layers: int = 4,
26
- conv3d_channels: int = 32,
27
- fc_features: int = 128,
28
- spatial_kernel_size: int = 3,
29
- temporal_kernel_size: int = 3,
30
- padding: Union[int, List[int]] = (1, 0, 0),
31
- ):
32
- """This is the original encoding module used in PVNet, with a few minor tweaks.
33
-
34
- Args:
35
- sequence_length: The time sequence length of the data.
36
- image_size_pixels: The spatial size of the image. Assumed square.
37
- in_channels: Number of input channels.
38
- out_features: Number of output features.
39
- number_of_conv3d_layers: Number of convolution 3d layers that are used.
40
- conv3d_channels: Number of channels used in each conv3d layer.
41
- fc_features: number of output nodes out of the hidden fully connected layer.
42
- spatial_kernel_size: The spatial size of the kernel used in the conv3d layers.
43
- temporal_kernel_size: The temporal size of the kernel used in the conv3d layers.
44
- padding: The padding used in the conv3d layers. If an int, the same padding
45
- is used in all dimensions
46
- """
47
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
48
- if isinstance(padding, int):
49
- padding = (padding, padding, padding)
50
- # Check that the output shape of the convolutional layers will be at least 1x1
51
- cnn_spatial_output_size = (
52
- image_size_pixels
53
- - ((spatial_kernel_size - 2 * padding[1]) - 1) * number_of_conv3d_layers
54
- )
55
- cnn_sequence_length = (
56
- sequence_length
57
- - ((temporal_kernel_size - 2 * padding[0]) - 1) * number_of_conv3d_layers
58
- )
59
- if not (cnn_spatial_output_size >= 1):
60
- raise ValueError(
61
- f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "
62
- f"spatial size ({image_size_pixels})"
63
- )
64
-
65
- conv_layers = []
66
-
67
- conv_layers += [
68
- nn.Conv3d(
69
- in_channels=in_channels,
70
- out_channels=conv3d_channels,
71
- kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size),
72
- padding=padding,
73
- ),
74
- nn.ELU(),
75
- ]
76
- for i in range(0, number_of_conv3d_layers - 1):
77
- conv_layers += [
78
- nn.Conv3d(
79
- in_channels=conv3d_channels,
80
- out_channels=conv3d_channels,
81
- kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size),
82
- padding=padding,
83
- ),
84
- nn.ELU(),
85
- ]
86
-
87
- self.conv_layers = nn.Sequential(*conv_layers)
88
-
89
- # Calculate the size of the output of the 3D convolutional layers
90
- cnn_output_size = conv3d_channels * cnn_spatial_output_size**2 * cnn_sequence_length
91
-
92
- self.final_block = nn.Sequential(
93
- nn.Linear(in_features=cnn_output_size, out_features=fc_features),
94
- nn.ELU(),
95
- nn.Linear(in_features=fc_features, out_features=out_features),
96
- nn.ELU(),
97
- )
98
-
99
- def forward(self, x):
100
- """Run model forward"""
101
- out = self.conv_layers(x)
102
- out = out.reshape(x.shape[0], -1)
103
-
104
- # Fully connected layers
105
- out = self.final_block(out)
106
-
107
- return out
108
-
109
-
110
- class DefaultPVNet2(AbstractNWPSatelliteEncoder):
111
- """The original encoding module used in PVNet, with a few minor tweaks, and batchnorm."""
112
-
113
- def __init__(
114
- self,
115
- sequence_length: int,
116
- image_size_pixels: int,
117
- in_channels: int,
118
- out_features: int,
119
- number_of_conv3d_layers: int = 4,
120
- conv3d_channels: int = 32,
121
- fc_features: int = 128,
122
- batch_norm=True,
123
- fc_dropout=0.2,
124
- ):
125
- """The original encoding module used in PVNet, with a few minor tweaks, and batchnorm.
126
-
127
- Args:
128
- sequence_length: The time sequence length of the data.
129
- image_size_pixels: The spatial size of the image. Assumed square.
130
- in_channels: Number of input channels.
131
- out_features: Number of output features.
132
- number_of_conv3d_layers: Number of convolution 3d layers that are used.
133
- conv3d_channels: Number of channels used in each conv3d layer.
134
- fc_features: number of output nodes out of the hidden fully connected layer.
135
- batch_norm: Whether to include 3D batch normalisation.
136
- fc_dropout: Probability of an element to be zeroed before the last two fully connected
137
- layers.
138
- """
139
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
140
-
141
- # Check that the output shape of the convolutional layers will be at least 1x1
142
- cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv3d_layers
143
- if not (cnn_spatial_output_size > 0):
144
- raise ValueError(
145
- f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "
146
- f"spatial size ({image_size_pixels})"
147
- )
148
-
149
- conv_layers = [
150
- nn.Conv3d(
151
- in_channels=in_channels,
152
- out_channels=conv3d_channels,
153
- kernel_size=(3, 3, 3),
154
- padding=(1, 0, 0),
155
- ),
156
- nn.LeakyReLU(),
157
- ]
158
- if batch_norm:
159
- # Inserted before activation using position -1
160
- conv_layers.insert(-1, nn.BatchNorm3d(conv3d_channels))
161
- for i in range(0, number_of_conv3d_layers - 1):
162
- conv_layers += [
163
- nn.Conv3d(
164
- in_channels=conv3d_channels,
165
- out_channels=conv3d_channels,
166
- kernel_size=(3, 3, 3),
167
- padding=(1, 0, 0),
168
- ),
169
- nn.LeakyReLU(),
170
- ]
171
- if batch_norm:
172
- # Inserted before activation using position -1
173
- conv_layers.insert(-1, nn.BatchNorm3d(conv3d_channels))
174
-
175
- self.conv_layers = nn.Sequential(*conv_layers)
176
-
177
- # Calculate the size of the output of the 3D convolutional layers
178
- cnn_output_size = conv3d_channels * cnn_spatial_output_size**2 * sequence_length
179
-
180
- final_block = [
181
- nn.Linear(in_features=cnn_output_size, out_features=fc_features),
182
- nn.LeakyReLU(),
183
- nn.Linear(in_features=fc_features, out_features=out_features),
184
- nn.LeakyReLU(),
185
- ]
186
-
187
- if fc_dropout > 0:
188
- # Insert after the linear layers
189
- final_block.insert(1, nn.Dropout(fc_dropout))
190
- final_block.insert(-1, nn.Dropout(fc_dropout))
191
-
192
- self.final_block = nn.Sequential(*final_block)
193
-
194
- def forward(self, x):
195
- """Run model forward"""
196
- out = self.conv_layers(x)
197
- out = out.reshape(x.shape[0], -1)
198
-
199
- # Fully connected layers
200
- out = self.final_block(out)
201
-
202
- return out
203
-
204
-
205
- class ResConv3DNet2(AbstractNWPSatelliteEncoder):
206
- """3D convolutional network based on ResNet architecture.
207
-
208
- The residual blocks are implemented based on the best performing block in [1].
209
-
210
- Sources:
211
- [1] https://arxiv.org/pdf/1603.05027.pdf
212
- """
213
-
214
- def __init__(
215
- self,
216
- sequence_length: int,
217
- image_size_pixels: int,
218
- in_channels: int,
219
- out_features: int,
220
- hidden_channels: int = 32,
221
- n_res_blocks: int = 4,
222
- res_block_layers: int = 2,
223
- batch_norm=True,
224
- dropout_frac=0.0,
225
- ):
226
- """3D convolutional network based on ResNet architecture.
227
-
228
- Args:
229
- sequence_length: The time sequence length of the data.
230
- image_size_pixels: The spatial size of the image. Assumed square.
231
- in_channels: Number of input channels.
232
- out_features: Number of output features.
233
- hidden_channels: Number of channels in middle hidden layers.
234
- n_res_blocks: Number of residual blocks to use.
235
- res_block_layers: Number of Conv3D layers used in each residual block.
236
- batch_norm: Whether to include batch normalisation.
237
- dropout_frac: Probability of an element to be zeroed in the residual pathways.
238
- """
239
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
240
-
241
- model = [
242
- nn.Conv3d(
243
- in_channels=in_channels,
244
- out_channels=hidden_channels,
245
- kernel_size=(3, 3, 3),
246
- padding=(1, 1, 1),
247
- ),
248
- ]
249
-
250
- for i in range(n_res_blocks):
251
- model.extend(
252
- [
253
- ResidualConv3dBlock2(
254
- in_channels=hidden_channels,
255
- n_layers=res_block_layers,
256
- dropout_frac=dropout_frac,
257
- batch_norm=batch_norm,
258
- ),
259
- nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2)),
260
- ]
261
- )
262
-
263
- # Calculate the size of the output of the 3D convolutional layers
264
- final_im_size = image_size_pixels // (2**n_res_blocks)
265
- cnn_output_size = hidden_channels * sequence_length * final_im_size * final_im_size
266
-
267
- model.extend(
268
- [
269
- nn.ELU(),
270
- nn.Flatten(start_dim=1, end_dim=-1),
271
- nn.Linear(in_features=cnn_output_size, out_features=out_features),
272
- nn.ELU(),
273
- ]
274
- )
275
-
276
- self.model = nn.Sequential(*model)
277
-
278
- def forward(self, x):
279
- """Run model forward"""
280
- return self.model(x)
281
-
282
-
283
- class EncoderUNET(AbstractNWPSatelliteEncoder):
284
- """An encoder based on a modified UNet architecture.
285
-
286
- An encoder for satellite and/or NWP data taking inspiration from the kinds of skip
287
- connections in UNet. This differs from an actual UNet in that it does not have upsampling
288
- layers, instead it concats features from different spatial scales, and applies a few extra
289
- conv3d layers.
290
- """
291
-
292
- def __init__(
293
- self,
294
- sequence_length: int,
295
- image_size_pixels: int,
296
- in_channels: int,
297
- out_features: int,
298
- n_downscale: int = 3,
299
- res_block_layers: int = 2,
300
- conv3d_channels: int = 32,
301
- dropout_frac: float = 0.1,
302
- ):
303
- """An encoder based on a modified UNet architecture.
304
-
305
- Args:
306
- sequence_length: The time sequence length of the data.
307
- image_size_pixels: The spatial size of the image. Assumed square.
308
- in_channels: Number of input channels.
309
- out_features: Number of output features.
310
- n_downscale: Number of conv3d and spatially downscaling layers that are used.
311
- res_block_layers: Number of layers used in each residual block.
312
- conv3d_channels: Number of channels used in each conv3d layer.
313
- dropout_frac: Probability of an element to be zeroed in the residual pathways.
314
- """
315
- cnn_spatial_output = image_size_pixels // (2**n_downscale)
316
-
317
- if not (cnn_spatial_output > 0):
318
- raise ValueError(
319
- f"cannot use this many downscaling layers ({n_downscale}) with this input "
320
- f"spatial size ({image_size_pixels})"
321
- )
322
-
323
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
324
-
325
- self.first_layer = nn.Sequential(
326
- nn.Conv3d(
327
- in_channels=in_channels,
328
- out_channels=conv3d_channels,
329
- kernel_size=(1, 1, 1),
330
- padding=(0, 0, 0),
331
- ),
332
- ResidualConv3dBlock(
333
- in_channels=conv3d_channels,
334
- n_layers=res_block_layers,
335
- dropout_frac=dropout_frac,
336
- ),
337
- )
338
-
339
- downscale_layers = []
340
- for _ in range(n_downscale):
341
- downscale_layers += [
342
- nn.Sequential(
343
- ResidualConv3dBlock(
344
- in_channels=conv3d_channels,
345
- n_layers=res_block_layers,
346
- dropout_frac=dropout_frac,
347
- ),
348
- nn.ELU(),
349
- nn.Conv3d(
350
- in_channels=conv3d_channels,
351
- out_channels=conv3d_channels,
352
- kernel_size=(1, 2, 2),
353
- padding=(0, 0, 0),
354
- stride=(1, 2, 2),
355
- ),
356
- )
357
- ]
358
-
359
- self.downscale_layers = nn.ModuleList(downscale_layers)
360
-
361
- self.crop_fn = CenterCrop(cnn_spatial_output)
362
-
363
- cat_channels = conv3d_channels * (1 + n_downscale)
364
- self.post_cat_conv = nn.Sequential(
365
- ResidualConv3dBlock(
366
- in_channels=cat_channels,
367
- n_layers=res_block_layers,
368
- ),
369
- nn.ELU(),
370
- nn.Conv3d(
371
- in_channels=cat_channels,
372
- out_channels=conv3d_channels,
373
- kernel_size=(1, 1, 1),
374
- ),
375
- )
376
-
377
- final_channels = (
378
- (image_size_pixels // (2**n_downscale)) ** 2 * conv3d_channels * sequence_length
379
- )
380
- self.final_layer = nn.Sequential(
381
- nn.ELU(),
382
- nn.Linear(
383
- in_features=final_channels,
384
- out_features=out_features,
385
- ),
386
- nn.ELU(),
387
- )
388
-
389
- def forward(self, x):
390
- """Run model forward"""
391
- out = self.first_layer(x)
392
- outputs = [self.crop_fn(out)]
393
-
394
- for layer in self.downscale_layers:
395
- out = layer(out)
396
- outputs += [self.crop_fn(out)]
397
-
398
- out = torch.cat(outputs, dim=1)
399
- out = self.post_cat_conv(out)
400
- out = torch.flatten(out, start_dim=1)
401
- out = self.final_layer(out)
402
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/encoders/encodersRNN.py DELETED
@@ -1,141 +0,0 @@
1
- """Encoder modules for the satellite/NWP data based on recurrent and 2D convolutional layers.
2
- """
3
-
4
- import torch
5
- from torch import nn
6
-
7
- from pvnet.models.multimodal.encoders.basic_blocks import (
8
- AbstractNWPSatelliteEncoder,
9
- ImageSequenceEncoder,
10
- )
11
-
12
-
13
- class ConvLSTM(AbstractNWPSatelliteEncoder):
14
- """Convolutional LSTM block from MetNet."""
15
-
16
- def __init__(
17
- self,
18
- sequence_length: int,
19
- image_size_pixels: int,
20
- in_channels: int,
21
- out_features: int,
22
- hidden_channels: int = 32,
23
- num_layers: int = 2,
24
- kernel_size: int = 3,
25
- bias: bool = True,
26
- activation=torch.tanh,
27
- batchnorm=False,
28
- ):
29
- """Convolutional LSTM block from MetNet.
30
-
31
- Args:
32
- sequence_length: The time sequence length of the data.
33
- image_size_pixels: The spatial size of the image. Assumed square.
34
- in_channels: Number of input channels.
35
- out_features: Number of output features.
36
- hidden_channels: Hidden dimension size.
37
- num_layers: Depth of ConvLSTM cells.
38
- kernel_size: Kernel size.
39
- bias: Whether to add bias.
40
- activation: Activation function for ConvLSTM cells.
41
- batchnorm: Whether to use batch norm.
42
- """
43
- from metnet.layers.ConvLSTM import ConvLSTM as _ConvLSTM
44
-
45
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
46
-
47
- self.conv_lstm = _ConvLSTM(
48
- input_dim=in_channels,
49
- hidden_dim=hidden_channels,
50
- kernel_size=kernel_size,
51
- num_layers=num_layers,
52
- bias=bias,
53
- activation=activation,
54
- batchnorm=batchnorm,
55
- )
56
-
57
- # Calculate the size of the output of the ConvLSTM network
58
- convlstm_output_size = hidden_channels * image_size_pixels**2
59
-
60
- self.final_block = nn.Sequential(
61
- nn.Linear(in_features=convlstm_output_size, out_features=out_features),
62
- nn.ELU(),
63
- )
64
-
65
- def forward(self, x):
66
- """Run model forward"""
67
-
68
- batch_size, channel, seq_len, height, width = x.shape
69
- x = torch.swapaxes(x, 1, 2)
70
-
71
- res, _ = self.conv_lstm(x)
72
-
73
- # Select last state only
74
- out = res[:, -1]
75
-
76
- # Flatten and fully connected layer
77
- out = out.reshape(batch_size, -1)
78
- out = self.final_block(out)
79
-
80
- return out
81
-
82
-
83
- class FlattenLSTM(AbstractNWPSatelliteEncoder):
84
- """Convolutional blocks followed by LSTM."""
85
-
86
- def __init__(
87
- self,
88
- sequence_length: int,
89
- image_size_pixels: int,
90
- in_channels: int,
91
- out_features: int,
92
- num_layers: int = 2,
93
- number_of_conv2d_layers: int = 4,
94
- conv2d_channels: int = 32,
95
- ):
96
- """Network consisting of 2D spatial convolutional and LSTM sequence encoder.
97
-
98
- Args:
99
- sequence_length: The time sequence length of the data.
100
- image_size_pixels: The spatial size of the image. Assumed square.
101
- in_channels: Number of input channels.
102
- out_features: Number of output features. Also used for LSTM hidden dimension.
103
- num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking
104
- two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of
105
- the first LSTM and computing the final results.
106
- number_of_conv2d_layers: Number of convolution 2D layers that are used.
107
- conv2d_channels: Number of channels used in each conv2d layer.
108
- """
109
-
110
- super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
111
-
112
- self.lstm = nn.LSTM(
113
- input_size=out_features,
114
- hidden_size=out_features,
115
- num_layers=num_layers,
116
- batch_first=True,
117
- )
118
-
119
- self.encode_image_sequence = ImageSequenceEncoder(
120
- image_size_pixels=image_size_pixels,
121
- in_channels=in_channels,
122
- number_of_conv2d_layers=number_of_conv2d_layers,
123
- conv2d_channels=conv2d_channels,
124
- fc_features=out_features,
125
- )
126
-
127
- self.final_block = nn.Sequential(
128
- nn.Linear(in_features=out_features, out_features=out_features),
129
- nn.ELU(),
130
- )
131
-
132
- def forward(self, x):
133
- """Run model forward"""
134
- encoded_images = self.encode_image_sequence(x)
135
-
136
- _, (_, c_n) = self.lstm(encoded_images)
137
-
138
- # Take only the deepest level hidden cell state
139
- out = self.final_block(c_n[-1])
140
-
141
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/linear_networks/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Submodels to combine 1D feature vectors from different sources and make final predictions"""
 
 
pvnet/models/multimodal/linear_networks/basic_blocks.py DELETED
@@ -1,121 +0,0 @@
1
- """Basic blocks for the linear networks"""
2
- from abc import ABCMeta, abstractmethod
3
- from collections import OrderedDict
4
-
5
- import torch
6
- from torch import nn
7
-
8
-
9
- class AbstractLinearNetwork(nn.Module, metaclass=ABCMeta):
10
- """Abstract class for a network to combine the features from all the inputs."""
11
-
12
- def __init__(
13
- self,
14
- in_features: int,
15
- out_features: int,
16
- ):
17
- """Abstract class for a network to combine the features from all the inputs.
18
-
19
- Args:
20
- in_features: Number of input features.
21
- out_features: Number of output features.
22
- """
23
- super().__init__()
24
-
25
- def cat_modes(self, x):
26
- """Concatenate modes of input data into 1D feature vector"""
27
- if isinstance(x, OrderedDict):
28
- return torch.cat([value for key, value in x.items()], dim=1)
29
- elif isinstance(x, torch.Tensor):
30
- return x
31
- else:
32
- raise ValueError(f"Input of unexpected type {type(x)}")
33
-
34
- @abstractmethod
35
- def forward(self):
36
- """Run model forward"""
37
- pass
38
-
39
-
40
- class ResidualLinearBlock(nn.Module):
41
- """A 1D fully-connected residual block using ELU activations and including optional dropout."""
42
-
43
- def __init__(
44
- self,
45
- in_features: int,
46
- n_layers: int = 2,
47
- dropout_frac: float = 0.0,
48
- ):
49
- """A 1D fully-connected residual block using ELU activations and including optional dropout.
50
-
51
- Args:
52
- in_features: Number of input features.
53
- n_layers: Number of layers in residual pathway.
54
- dropout_frac: Probability of an element to be zeroed.
55
- """
56
- super().__init__()
57
-
58
- layers = []
59
- for i in range(n_layers):
60
- layers += [
61
- nn.ELU(),
62
- nn.Linear(
63
- in_features=in_features,
64
- out_features=in_features,
65
- ),
66
- nn.Dropout(p=dropout_frac),
67
- ]
68
- self.model = nn.Sequential(*layers)
69
-
70
- def forward(self, x):
71
- """Run model forward"""
72
- return self.model(x) + x
73
-
74
-
75
- class ResidualLinearBlock2(nn.Module):
76
- """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
77
-
78
- This was the best performing residual block tested in the study. This implementation differs
79
- from that block just by using LeakyReLU activation to avoid dead neurons, and by including
80
- optional dropout in the residual branch. This is also a 1D fully connected layer residual block
81
- rather than a 2D convolutional block.
82
-
83
- Sources:
84
- [1] https://arxiv.org/pdf/1603.05027.pdf
85
- """
86
-
87
- def __init__(
88
- self,
89
- in_features: int,
90
- n_layers: int = 2,
91
- dropout_frac: float = 0.0,
92
- ):
93
- """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].
94
-
95
- Sources:
96
- [1] https://arxiv.org/pdf/1603.05027.pdf
97
-
98
- Args:
99
- in_features: Number of input features.
100
- n_layers: Number of layers in residual pathway.
101
- dropout_frac: Probability of an element to be zeroed.
102
- """
103
- super().__init__()
104
-
105
- layers = []
106
- for i in range(n_layers):
107
- layers += [
108
- nn.BatchNorm1d(in_features),
109
- nn.Dropout(p=dropout_frac),
110
- nn.LeakyReLU(),
111
- nn.Linear(
112
- in_features=in_features,
113
- out_features=in_features,
114
- ),
115
- ]
116
-
117
- self.model = nn.Sequential(*layers)
118
-
119
- def forward(self, x):
120
- """Run model forward"""
121
- return self.model(x) + x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/linear_networks/networks.py DELETED
@@ -1,332 +0,0 @@
1
- """Linear networks used for the fusion model"""
2
- from torch import nn, rand
3
-
4
- from pvnet.models.multimodal.linear_networks.basic_blocks import (
5
- AbstractLinearNetwork,
6
- ResidualLinearBlock,
7
- ResidualLinearBlock2,
8
- )
9
-
10
-
11
- class DefaultFCNet(AbstractLinearNetwork):
12
- """Similar to the original FCNet module used in PVNet, with a few minor tweaks.
13
-
14
- This is a 2-layer fully connected block, with internal ELU activations and output ReLU.
15
- """
16
-
17
- def __init__(
18
- self,
19
- in_features: int,
20
- out_features: int,
21
- fc_hidden_features: int = 128,
22
- ):
23
- """Similar to the original FCNet module used in PVNet, with a few minor tweaks.
24
-
25
- Args:
26
- in_features: Number of input features.
27
- out_features: Number of output features.
28
- fc_hidden_features: Number of features in middle hidden layer.
29
- """
30
- super().__init__(in_features, out_features)
31
-
32
- self.model = nn.Sequential(
33
- nn.Linear(in_features=in_features, out_features=fc_hidden_features),
34
- nn.ELU(),
35
- nn.Linear(in_features=fc_hidden_features, out_features=out_features),
36
- nn.ReLU(),
37
- )
38
-
39
- def forward(self, x):
40
- """Run model forward"""
41
- x = self.cat_modes(x)
42
- return self.model(x)
43
-
44
-
45
- class ResFCNet(AbstractLinearNetwork):
46
- """Fully-connected deep network based on ResNet architecture.
47
-
48
- Internally, this network uses ELU activations throughout the residual blocks.
49
- With n_res_blocks=0 this matches `DefaultFCNet` except for its LeakyReLU output activation.
50
- """
51
-
52
- def __init__(
53
- self,
54
- in_features: int,
55
- out_features: int,
56
- fc_hidden_features: int = 128,
57
- n_res_blocks: int = 4,
58
- res_block_layers: int = 2,
59
- dropout_frac: float = 0.2,
60
- ):
61
- """Fully-connected deep network based on ResNet architecture.
62
-
63
- Args:
64
- in_features: Number of input features.
65
- out_features: Number of output features.
66
- fc_hidden_features: Number of features in middle hidden layers.
67
- n_res_blocks: Number of residual blocks to use.
68
- res_block_layers: Number of fully-connected layers used in each residual block.
69
- dropout_frac: Probability of an element to be zeroed in the residual pathways.
70
- """
71
- super().__init__(in_features, out_features)
72
-
73
- model = [
74
- nn.Linear(in_features=in_features, out_features=fc_hidden_features),
75
- ]
76
-
77
- for i in range(n_res_blocks):
78
- model += [
79
- ResidualLinearBlock(
80
- in_features=fc_hidden_features,
81
- n_layers=res_block_layers,
82
- dropout_frac=dropout_frac,
83
- )
84
- ]
85
-
86
- model += [
87
- nn.ELU(),
88
- nn.Linear(in_features=fc_hidden_features, out_features=out_features),
89
- nn.LeakyReLU(negative_slope=0.01),
90
- ]
91
- self.model = nn.Sequential(*model)
92
-
93
- def forward(self, x):
94
- """Run model forward"""
95
- x = self.cat_modes(x)
96
- return self.model(x)
97
-
98
-
99
- class ResFCNet2(AbstractLinearNetwork):
100
- """Fully connected deep network based on ResNet architecture.
101
-
102
- This architecture is similar to
103
- `ResFCNet`, except that it uses LeakyReLU activations internally, and batchnorm in the residual
104
- branches. The residual blocks are implemented based on the best performing block in [1].
105
-
106
- Sources:
107
- [1] https://arxiv.org/pdf/1603.05027.pdf
108
- """
109
-
110
- def __init__(
111
- self,
112
- in_features: int,
113
- out_features: int,
114
- fc_hidden_features: int = 128,
115
- n_res_blocks: int = 4,
116
- res_block_layers: int = 2,
117
- dropout_frac=0.0,
118
- ):
119
- """Fully connected deep network based on ResNet architecture.
120
-
121
- Args:
122
- in_features: Number of input features.
123
- out_features: Number of output features.
124
- fc_hidden_features: Number of features in middle hidden layers.
125
- n_res_blocks: Number of residual blocks to use.
126
- res_block_layers: Number of fully-connected layers used in each residual block.
127
- dropout_frac: Probability of an element to be zeroed in the residual pathways.
128
- """
129
- super().__init__(in_features, out_features)
130
-
131
- model = [
132
- nn.Linear(in_features=in_features, out_features=fc_hidden_features),
133
- ]
134
-
135
- for i in range(n_res_blocks):
136
- model += [
137
- ResidualLinearBlock2(
138
- in_features=fc_hidden_features,
139
- n_layers=res_block_layers,
140
- dropout_frac=dropout_frac,
141
- )
142
- ]
143
-
144
- model += [
145
- nn.LeakyReLU(),
146
- nn.Linear(in_features=fc_hidden_features, out_features=out_features),
147
- nn.LeakyReLU(negative_slope=0.01),
148
- ]
149
-
150
- self.model = nn.Sequential(*model)
151
-
152
- def forward(self, x):
153
- """Run model forward"""
154
- x = self.cat_modes(x)
155
- return self.model(x)
156
-
157
-
158
- class SNN(AbstractLinearNetwork):
159
- """Self normalising neural network implementation borrowed from [1] and proposed in [2].
160
-
161
- Sources:
162
- [1] https://github.com/tonyduan/snn/blob/master/snn/models.py
163
- [2] https://arxiv.org/pdf/1706.02515v5.pdf
164
-
165
- Args:
166
- in_features: Number of input features.
167
- out_features: Number of output features.
168
- fc_hidden_features: Number of features in middle hidden layers.
169
- n_layers: Number of fully-connected layers used in the network.
170
- dropout_frac: Probability of an element to be zeroed.
171
-
172
- """
173
-
174
- def __init__(
175
- self,
176
- in_features: int,
177
- out_features: int,
178
- fc_hidden_features: int = 128,
179
- n_layers: int = 10,
180
- dropout_frac: float = 0.0,
181
- ):
182
- """Self normalising neural network implementation borrowed from [1] and proposed in [2].
183
-
184
- Sources:
185
- [1] https://github.com/tonyduan/snn/blob/master/snn/models.py
186
- [2] https://arxiv.org/pdf/1706.02515v5.pdf
187
-
188
- Args:
189
- in_features: Number of input features.
190
- out_features: Number of output features.
191
- fc_hidden_features: Number of features in middle hidden layers.
192
- n_layers: Number of fully-connected layers used in the network.
193
- dropout_frac: Probability of an element to be zeroed.
194
-
195
- """
196
- super().__init__(in_features, out_features)
197
-
198
- layers = [
199
- nn.Linear(in_features, fc_hidden_features, bias=False),
200
- nn.SELU(),
201
- nn.AlphaDropout(p=dropout_frac),
202
- ]
203
- for i in range(1, n_layers - 1):
204
- layers += [
205
- nn.Linear(fc_hidden_features, fc_hidden_features, bias=False),
206
- nn.SELU(),
207
- nn.AlphaDropout(p=dropout_frac),
208
- ]
209
- layers += [
210
- nn.Linear(fc_hidden_features, out_features, bias=True),
211
- nn.LeakyReLU(negative_slope=0.01),
212
- ]
213
-
214
- self.network = nn.Sequential(*layers)
215
- self._reset_parameters()
216
-
217
- def forward(self, x):
218
- """Run model forward"""
219
- x = self.cat_modes(x)
220
- return self.network(x)
221
-
222
- def _reset_parameters(self):
223
- for layer in self.network:
224
- if isinstance(layer, nn.Linear):
225
- nn.init.normal_(layer.weight, std=layer.out_features**-0.5)
226
- if layer.bias is not None:
227
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
228
- bound = fan_in**-0.5
229
- nn.init.uniform_(layer.bias, -bound, bound)
230
-
231
-
232
- class TabNet(AbstractLinearNetwork):
233
- """An implementation of TabNet [1].
234
-
235
- The implementation comes from `pytorch_tabnet` and this must be installed for use.
236
-
237
-
238
- Sources:
239
- [1] https://arxiv.org/abs/1908.07442
240
- """
241
-
242
- def __init__(
243
- self,
244
- in_features: int,
245
- out_features: int,
246
- n_d=8,
247
- n_a=8,
248
- n_steps=3,
249
- gamma=1.3,
250
- cat_idxs=[],
251
- cat_dims=[],
252
- cat_emb_dim=1,
253
- n_independent=2,
254
- n_shared=2,
255
- epsilon=1e-15,
256
- virtual_batch_size=128,
257
- momentum=0.02,
258
- mask_type="sparsemax",
259
- ):
260
- """An implementation of TabNet [1].
261
-
262
- Sources:
263
- [1] https://arxiv.org/abs/1908.07442
264
-
265
- Args:
266
- in_features: int
267
- Number of input features.
268
- out_features: int
269
- Number of output features.
270
- n_d : int
271
- Dimension of the prediction layer (usually between 4 and 64)
272
- n_a : int
273
- Dimension of the attention layer (usually between 4 and 64)
274
- n_steps : int
275
- Number of successive steps in the network (usually between 3 and 10)
276
- gamma : float
277
- Float above 1, scaling factor for attention updates (usually between 1.0 and 2.0)
278
- cat_idxs : list of int
279
- Index of each categorical column in the dataset
280
- cat_dims : list of int
281
- Number of categories in each categorical column
282
- cat_emb_dim : int or list of int
283
- Size of the embedding of categorical features
284
- if int, all categorical features will have same embedding size
285
- if list of int, every corresponding feature will have specific size
286
- n_independent : int
287
- Number of independent GLU layer in each GLU block (default 2)
288
- n_shared : int
289
- Number of shared GLU layers in each GLU block (default 2)
290
- epsilon : float
291
- Avoid log(0), this should be kept very low
292
- virtual_batch_size : int
293
- Batch size for Ghost Batch Normalization
294
- momentum : float
295
- Float value between 0 and 1 which will be used for momentum in all batch norm
296
- mask_type : str
297
- Either "sparsemax" or "entmax" : this is the masking function to use
298
- """
299
- from pytorch_tabnet.tab_network import TabNet as _TabNetModel
300
-
301
- super().__init__(in_features, out_features)
302
-
303
- self._tabnet = _TabNetModel(
304
- input_dim=in_features,
305
- output_dim=out_features,
306
- n_d=n_d,
307
- n_a=n_a,
308
- n_steps=n_steps,
309
- gamma=gamma,
310
- cat_idxs=cat_idxs,
311
- cat_dims=cat_dims,
312
- cat_emb_dim=cat_emb_dim,
313
- n_independent=n_independent,
314
- n_shared=n_shared,
315
- epsilon=epsilon,
316
- virtual_batch_size=virtual_batch_size,
317
- momentum=momentum,
318
- mask_type=mask_type,
319
- group_attention_matrix=rand(4, in_features),
320
- )
321
-
322
- self.activation = nn.LeakyReLU(negative_slope=0.01)
323
-
324
- def forward(self, x):
325
- """Run model forward"""
326
- # TODO: USE THIS LOSS COMPONENT
327
- # loss = self.compute_loss(output, y)
328
- # Add the overall sparsity loss
329
- # loss = loss - self.lambda_sparse * M_loss
330
- x = self.cat_modes(x)
331
- out1, M_loss = self._tabnet(x)
332
- return self.activation(out1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/multimodal.py DELETED
@@ -1,417 +0,0 @@
1
- """The default composite model architecture for PVNet"""
2
-
3
- import logging
4
- from collections import OrderedDict
5
- from typing import Any, Optional
6
-
7
- import torch
8
- from omegaconf import DictConfig
9
- from torch import nn
10
-
11
- import pvnet
12
- from pvnet.models.base_model import BaseModel
13
- from pvnet.models.multimodal.basic_blocks import ImageEmbedding
14
- from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
15
- from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
16
- from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder
17
- from pvnet.optimizers import AbstractOptimizer
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- class Model(BaseModel):
23
- """Neural network which combines information from different sources
24
-
25
- Architecture is roughly as follows:
26
-
27
- - Satellite data, if included, is put through an encoder which transforms it from 4D, with time,
28
- channel, height, and width dimensions to become a 1D feature vector.
29
- - NWP, if included, is put through a similar encoder.
30
- - PV site-level data, if included, is put through an encoder which transforms it from 2D, with
31
- time and system-ID dimensions, to become a 1D feature vector.
32
- - The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun
33
- paramters* are concatenated into a 1D feature vector and passed through another neural
34
- network to combine them and produce a forecast.
35
-
36
- * if included
37
- """
38
-
39
- name = "conv3d_sat_nwp"
40
-
41
- def __init__(
42
- self,
43
- output_network: AbstractLinearNetwork,
44
- output_quantiles: Optional[list[float]] = None,
45
- nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None,
46
- sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
47
- pv_encoder: Optional[AbstractSitesEncoder] = None,
48
- sensor_encoder: Optional[AbstractSitesEncoder] = None,
49
- add_image_embedding_channel: bool = False,
50
- include_gsp_yield_history: bool = True,
51
- include_site_yield_history: Optional[bool] = False,
52
- include_sun: bool = True,
53
- include_time: bool = False,
54
- location_id_mapping: Optional[dict[Any, int]] = None,
55
- embedding_dim: Optional[int] = 16,
56
- forecast_minutes: int = 30,
57
- history_minutes: int = 60,
58
- sat_history_minutes: Optional[int] = None,
59
- min_sat_delay_minutes: Optional[int] = 30,
60
- nwp_forecast_minutes: Optional[DictConfig] = None,
61
- nwp_history_minutes: Optional[DictConfig] = None,
62
- pv_history_minutes: Optional[int] = None,
63
- sensor_history_minutes: Optional[int] = None,
64
- sensor_forecast_minutes: Optional[int] = None,
65
- optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
66
- target_key: str = "gsp",
67
- interval_minutes: int = 30,
68
- nwp_interval_minutes: Optional[DictConfig] = None,
69
- pv_interval_minutes: int = 5,
70
- sat_interval_minutes: int = 5,
71
- sensor_interval_minutes: int = 30,
72
- timestep_intervals_to_plot: Optional[list[int]] = None,
73
- adapt_batches: Optional[bool] = False,
74
- forecast_minutes_ignore: Optional[int] = 0,
75
- save_validation_results_csv: Optional[bool] = False,
76
- ):
77
- """Neural network which combines information from different sources.
78
-
79
- Notes:
80
- In the args, where it says a module `m` is partially instantiated, it means that a
81
- normal pytorch module will be returned by running `mod = m(**kwargs)`. In this library,
82
- this partial instantiation is generally achieved using partial instantiation via hydra.
83
- However, the arg is still valid as long as `m(**kwargs)` returns a valid pytorch module
84
- - for example if `m` is a regular function.
85
-
86
- Args:
87
- output_network: A partially instantiated pytorch Module class used to combine the 1D
88
- features to produce the forecast.
89
- output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
90
- None the output is a single value.
91
- nwp_encoders_dict: A dictionary of partially instantiated pytorch Module class used to
92
- encode the NWP data from 4D into a 1D feature vector from different sources.
93
- sat_encoder: A partially instantiated pytorch Module class used to encode the satellite
94
- data from 4D into a 1D feature vector.
95
- pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
96
- PV data from 2D into a 1D feature vector.
97
- add_image_embedding_channel: Add a channel to the NWP and satellite data with the
98
- embedding of the GSP ID.
99
- include_gsp_yield_history: Include GSP yield data.
100
- include_site_yield_history: Include Site yield data.
101
- include_sun: Include sun azimuth and altitude data.
102
- include_time: Include sine and cosine of dates and times.
103
- location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
104
- not used if this is not provided.
105
- embedding_dim: Number of embedding dimensions to use for GSP ID.
106
- forecast_minutes: The amount of minutes that should be forecasted.
107
- history_minutes: The default amount of historical minutes that are used.
108
- sat_history_minutes: Length of recent observations used for satellite inputs. Defaults
109
- to `history_minutes` if not provided.
110
- min_sat_delay_minutes: Minimum delay with respect to t0 of the latest available
111
- satellite image.
112
- nwp_forecast_minutes: Period of future NWP forecast data used as input. Defaults to
113
- `forecast_minutes` if not provided.
114
- nwp_history_minutes: Period of historical NWP forecast used as input. Defaults to
115
- `history_minutes` if not provided.
116
- pv_history_minutes: Length of recent site-level PV data used as
117
- input. Defaults to `history_minutes` if not provided.
118
- optimizer: Optimizer factory function used for network.
119
- target_key: The key of the target variable in the batch.
120
- interval_minutes: The interval between each sample of the target data
121
- nwp_interval_minutes: Dictionary of the intervals between each sample of the NWP
122
- data for each source
123
- pv_interval_minutes: The interval between each sample of the PV data
124
- sat_interval_minutes: The interval between each sample of the satellite data
125
- sensor_interval_minutes: The interval between each sample of the sensor data
126
- timestep_intervals_to_plot: Intervals, in timesteps, to plot in
127
- addition to the full forecast
128
- sensor_encoder: Encoder for sensor data
129
- sensor_history_minutes: Length of recent sensor data used as input.
130
- sensor_forecast_minutes: Length of forecast sensor data used as input.
131
- adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
132
- the model to use. This allows us to overprepare batches and slice from them for the
133
- data we need for a model run.
134
- forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses.
135
- For example if set to 60, the model doesnt predict the first 60 minutes
136
- save_validation_results_csv: whether to save full csv outputs from validation results.
137
- """
138
-
139
- self.include_gsp_yield_history = include_gsp_yield_history
140
- self.include_site_yield_history = include_site_yield_history
141
- self.include_sat = sat_encoder is not None
142
- self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
143
- self.include_pv = pv_encoder is not None
144
- self.include_sun = include_sun
145
- self.include_time = include_time
146
- self.include_sensor = sensor_encoder is not None
147
- self.location_id_mapping = location_id_mapping
148
- self.embedding_dim = embedding_dim
149
- self.add_image_embedding_channel = add_image_embedding_channel
150
- self.interval_minutes = interval_minutes
151
- self.min_sat_delay_minutes = min_sat_delay_minutes
152
- self.adapt_batches = adapt_batches
153
-
154
- if self.location_id_mapping is None:
155
- logger.warning("location_id_mapping` is not provided, "
156
- "defaulting to outdated GSP mapping (0 to 317)")
157
-
158
- # Note 318 is the 2024 UK GSP count, so this is a temporary fix
159
- # for models trained with this default embedding
160
- self.location_id_mapping = {i: i for i in range(318)}
161
-
162
- # in the future location_id_mapping could be None,
163
- # and in this case use_id_embedding should be False
164
- self.use_id_embedding = self.embedding_dim is not None
165
-
166
- if self.use_id_embedding:
167
- num_embeddings = max(self.location_id_mapping.values()) + 1
168
-
169
- super().__init__(
170
- history_minutes=history_minutes,
171
- forecast_minutes=forecast_minutes,
172
- optimizer=optimizer,
173
- output_quantiles=output_quantiles,
174
- target_key=target_key,
175
- interval_minutes=interval_minutes,
176
- timestep_intervals_to_plot=timestep_intervals_to_plot,
177
- forecast_minutes_ignore=forecast_minutes_ignore,
178
- save_validation_results_csv=save_validation_results_csv
179
- )
180
-
181
- # Number of features expected by the output_network
182
- # Add to this as network pieces are constructed
183
- fusion_input_features = 0
184
-
185
- if self.include_sat:
186
- # Param checks
187
- assert sat_history_minutes is not None
188
-
189
- self.sat_sequence_len = (
190
- sat_history_minutes - min_sat_delay_minutes
191
- ) // sat_interval_minutes + 1
192
-
193
- self.sat_encoder = sat_encoder(
194
- sequence_length=self.sat_sequence_len,
195
- in_channels=sat_encoder.keywords["in_channels"] + add_image_embedding_channel,
196
- )
197
- if add_image_embedding_channel:
198
- self.sat_embed = ImageEmbedding(
199
- num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels
200
- )
201
-
202
- # Update num features
203
- fusion_input_features += self.sat_encoder.out_features
204
-
205
- if self.include_nwp:
206
- # Param checks
207
- assert nwp_forecast_minutes is not None
208
- assert nwp_history_minutes is not None
209
-
210
- # For each NWP encoder the forecast and history minutes must be set
211
- assert set(nwp_encoders_dict.keys()) == set(nwp_forecast_minutes.keys())
212
- assert set(nwp_encoders_dict.keys()) == set(nwp_history_minutes.keys())
213
-
214
- if nwp_interval_minutes is None:
215
- nwp_interval_minutes = dict.fromkeys(nwp_encoders_dict.keys(), 60)
216
-
217
- self.nwp_encoders_dict = torch.nn.ModuleDict()
218
- if add_image_embedding_channel:
219
- self.nwp_embed_dict = torch.nn.ModuleDict()
220
-
221
- for nwp_source in nwp_encoders_dict.keys():
222
- nwp_sequence_len = (
223
- nwp_history_minutes[nwp_source] // nwp_interval_minutes[nwp_source]
224
- + nwp_forecast_minutes[nwp_source] // nwp_interval_minutes[nwp_source]
225
- + 1
226
- )
227
-
228
- self.nwp_encoders_dict[nwp_source] = nwp_encoders_dict[nwp_source](
229
- sequence_length=nwp_sequence_len,
230
- in_channels=(
231
- nwp_encoders_dict[nwp_source].keywords["in_channels"]
232
- + add_image_embedding_channel
233
- ),
234
- )
235
- if add_image_embedding_channel:
236
- self.nwp_embed_dict[nwp_source] = ImageEmbedding(
237
- num_embeddings,
238
- nwp_sequence_len,
239
- self.nwp_encoders_dict[nwp_source].image_size_pixels,
240
- )
241
-
242
- # Update num features
243
- fusion_input_features += self.nwp_encoders_dict[nwp_source].out_features
244
-
245
- if self.include_pv:
246
- assert pv_history_minutes is not None
247
-
248
- self.pv_encoder = pv_encoder(
249
- sequence_length=pv_history_minutes // pv_interval_minutes + 1,
250
- target_key_to_use=self._target_key,
251
- input_key_to_use="site",
252
- )
253
-
254
- # Update num features
255
- fusion_input_features += self.pv_encoder.out_features
256
-
257
- if self.include_sensor:
258
- if sensor_history_minutes is None:
259
- sensor_history_minutes = history_minutes
260
- if sensor_forecast_minutes is None:
261
- sensor_forecast_minutes = forecast_minutes
262
-
263
- self.sensor_encoder = sensor_encoder(
264
- sequence_length=sensor_history_minutes // sensor_interval_minutes
265
- + sensor_forecast_minutes // sensor_interval_minutes
266
- + 1,
267
- target_key_to_use=self._target_key,
268
- input_key_to_use="sensor",
269
- )
270
-
271
- # Update num features
272
- fusion_input_features += self.sensor_encoder.out_features
273
-
274
- if self.use_id_embedding:
275
- self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
276
-
277
- # Update num features
278
- fusion_input_features += embedding_dim
279
-
280
- if self.include_sun:
281
- self.sun_fc1 = nn.Linear(
282
- in_features=2
283
- * (self.forecast_len + self.forecast_len_ignore + self.history_len + 1),
284
- out_features=16,
285
- )
286
-
287
- # Update num features
288
- fusion_input_features += 16
289
-
290
- if self.include_time:
291
- self.time_fc1 = nn.Linear(
292
- in_features=4
293
- * (self.forecast_len + self.forecast_len_ignore + self.history_len + 1),
294
- out_features=32,
295
- )
296
-
297
- # Update num features
298
- fusion_input_features += 32
299
-
300
- if include_gsp_yield_history:
301
- # Update num features
302
- fusion_input_features += self.history_len
303
-
304
- if include_site_yield_history:
305
- # Update num features
306
- fusion_input_features += self.history_len + 1
307
-
308
- self.output_network = output_network(
309
- in_features=fusion_input_features,
310
- out_features=self.num_output_features,
311
- )
312
-
313
- self.save_hyperparameters()
314
-
315
- def forward(self, x):
316
- """Run model forward"""
317
-
318
- if self.adapt_batches:
319
- x = self._adapt_batch(x)
320
-
321
- if self.use_id_embedding:
322
- # eg: x['gsp_id] = [1] with location_id_mapping = {1:0}, would give [0]
323
- id = torch.tensor(
324
- [self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
325
- device=self.device,
326
- dtype=torch.int64,
327
- )
328
-
329
- modes = OrderedDict()
330
- # ******************* Satellite imagery *************************
331
- if self.include_sat:
332
- # Shape: batch_size, seq_length, channel, height, width
333
- sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
334
- sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
335
-
336
- if self.add_image_embedding_channel:
337
- sat_data = self.sat_embed(sat_data, id)
338
- modes["sat"] = self.sat_encoder(sat_data)
339
-
340
- # *********************** NWP Data ************************************
341
- if self.include_nwp:
342
- # Loop through potentially many NMPs
343
- for nwp_source in self.nwp_encoders_dict:
344
- # shape: batch_size, seq_len, n_chans, height, width
345
- nwp_data = x["nwp"][nwp_source]["nwp"].float()
346
- nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
347
- # Some NWP variables can overflow into NaNs when normalised if they have extreme
348
- # tails
349
- nwp_data = torch.clip(nwp_data, min=-50, max=50)
350
-
351
- if self.add_image_embedding_channel:
352
- nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)
353
-
354
- nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
355
- modes[f"nwp/{nwp_source}"] = nwp_out
356
-
357
- # *********************** Site Data *************************************
358
- # Add site-level yield history
359
- if self.include_site_yield_history:
360
- site_history = x["site"][:, : self.history_len + 1].float()
361
- site_history = site_history.reshape(site_history.shape[0], -1)
362
- modes["site"] = site_history
363
-
364
- # Add site-level yield history through PV encoder
365
- if self.include_pv:
366
- if self._target_key != "site":
367
- modes["site"] = self.pv_encoder(x)
368
- else:
369
- # Target is PV, so only take the history
370
- # Copy batch
371
- x_tmp = x.copy()
372
- x_tmp["site"] = x_tmp["site"][:, : self.history_len + 1]
373
- modes["site"] = self.pv_encoder(x_tmp)
374
-
375
- # *********************** GSP Data ************************************
376
- # add gsp yield history
377
- if self.include_gsp_yield_history:
378
- gsp_history = x["gsp"][:, : self.history_len].float()
379
- gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
380
- modes["gsp"] = gsp_history
381
-
382
- # ********************** Embedding of GSP/Site ID ********************
383
- if self.use_id_embedding:
384
- modes["id"] = self.embed(id)
385
-
386
- if self.include_sun:
387
- # Use only new direct keys
388
- sun = torch.cat(
389
- (
390
- x["solar_azimuth"],
391
- x["solar_elevation"],
392
- ),
393
- dim=1,
394
- ).float()
395
- sun = self.sun_fc1(sun)
396
- modes["sun"] = sun
397
-
398
- if self.include_time:
399
- time = torch.cat(
400
- (
401
- x[f"{self._target_key}_date_sin"],
402
- x[f"{self._target_key}_date_cos"],
403
- x[f"{self._target_key}_time_sin"],
404
- x[f"{self._target_key}_time_cos"],
405
- ),
406
- dim=1,
407
- ).float()
408
- time = self.time_fc1(time)
409
- modes["time"] = time
410
-
411
- out = self.output_network(modes)
412
-
413
- if self.use_quantile_regression:
414
- # Shape: batch_size, seq_length * num_quantiles
415
- out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))
416
-
417
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/readme.md DELETED
@@ -1,11 +0,0 @@
1
- ## Multimodal model architecture
2
-
3
- These are fusion models which predict GSP power output based on NWP, non-HRV satellite, GSP output history, solar coordinates, and GSP ID.
4
-
5
- The core model is `multimodal.Model`, and its architecture is shown in the diagram below.
6
-
7
- ![multimodal_model_diagram](https://github.com/openclimatefix/PVNet/assets/41546094/118393fa-52ec-4bfe-a0a3-268c94c25f1e)
8
-
9
- This model uses encoders which take 4D (time, channel, x, y) inputs of NWP and satellite and encode them into 1D feature vectors. Different encoders are contained inside `encoders`.
10
-
11
- Different choices for the fusion model are contained inside `linear_networks`.
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/site_encoders/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Submodels to encode site-level PV data"""
 
 
pvnet/models/multimodal/site_encoders/basic_blocks.py DELETED
@@ -1,35 +0,0 @@
1
- """Basic blocks for PV-site encoders"""
2
- from abc import ABCMeta, abstractmethod
3
-
4
- from torch import nn
5
-
6
-
7
- class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta):
8
- """Abstract class for encoder for output data from multiple PV sites.
9
-
10
- The encoder will take an input of shape (batch_size, sequence_length, num_sites)
11
- and return an output of shape (batch_size, out_features).
12
- """
13
-
14
- def __init__(
15
- self,
16
- sequence_length: int,
17
- num_sites: int,
18
- out_features: int,
19
- ):
20
- """Abstract class for PV site-level encoder.
21
-
22
- Args:
23
- sequence_length: The time sequence length of the data.
24
- num_sites: Number of PV sites in the input data.
25
- out_features: Number of output features.
26
- """
27
- super().__init__()
28
- self.sequence_length = sequence_length
29
- self.num_sites = num_sites
30
- self.out_features = out_features
31
-
32
- @abstractmethod
33
- def forward(self):
34
- """Run model forward"""
35
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/site_encoders/encoders.py DELETED
@@ -1,284 +0,0 @@
1
- """Encoder modules for the site-level PV data.
2
-
3
- """
4
-
5
- import einops
6
- import torch
7
- from torch import nn
8
-
9
- from pvnet.models.multimodal.linear_networks.networks import ResFCNet2
10
- from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder
11
-
12
-
13
- class SimpleLearnedAggregator(AbstractSitesEncoder):
14
- """A simple model which learns a different weighted-average across all PV sites for each GSP.
15
-
16
- Each sequence from each site is independently encodeded through some dense layers wih skip-
17
- connections, then the encoded form of each sequence is aggregated through a learned weighted-sum
18
- and finally put through more dense layers.
19
-
20
- This model was written to be a simplified version of a single-headed attention layer.
21
- """
22
-
23
- def __init__(
24
- self,
25
- sequence_length: int,
26
- num_sites: int,
27
- out_features: int,
28
- value_dim: int = 10,
29
- value_enc_resblocks: int = 2,
30
- final_resblocks: int = 2,
31
- ):
32
- """A simple sequence encoder and weighted-average model.
33
-
34
- Args:
35
- sequence_length: The time sequence length of the data.
36
- num_sites: Number of PV sites in the input data.
37
- out_features: Number of output features.
38
- value_dim: The number of features in each encoded sequence. Similar to the value
39
- dimension in single- or multi-head attention.
40
- value_dim: The number of features in each encoded sequence. Similar to the value
41
- dimension in single- or multi-head attention.
42
- value_enc_resblocks: Number of residual blocks in the value-encoder sub-network.
43
- final_resblocks: Number of residual blocks in the final sub-network.
44
- """
45
-
46
- super().__init__(sequence_length, num_sites, out_features)
47
-
48
- # Network used to encode each PV site sequence
49
- self._value_encoder = nn.Sequential(
50
- ResFCNet2(
51
- in_features=sequence_length,
52
- out_features=value_dim,
53
- fc_hidden_features=value_dim,
54
- n_res_blocks=value_enc_resblocks,
55
- res_block_layers=2,
56
- dropout_frac=0,
57
- ),
58
- )
59
-
60
- # The learned weighted average is stored in an embedding layer for ease of use
61
- self._attention_network = nn.Sequential(
62
- nn.Embedding(318, num_sites),
63
- nn.Softmax(dim=1),
64
- )
65
-
66
- # Network used to process weighted average
67
- self.output_network = ResFCNet2(
68
- in_features=value_dim,
69
- out_features=out_features,
70
- fc_hidden_features=value_dim,
71
- n_res_blocks=final_resblocks,
72
- res_block_layers=2,
73
- dropout_frac=0,
74
- )
75
-
76
- def _calculate_attention(self, x):
77
- gsp_ids = x["gsp_id"].squeeze().int()
78
- attention = self._attention_network(gsp_ids)
79
- return attention
80
-
81
- def _encode_value(self, x):
82
- # Shape: [batch size, sequence length, PV site]
83
- pv_site_seqs = x["pv"].float()
84
- batch_size = pv_site_seqs.shape[0]
85
-
86
- pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1)
87
-
88
- x_seq_enc = self._value_encoder(pv_site_seqs)
89
- x_seq_out = x_seq_enc.unflatten(0, (batch_size, self.num_sites))
90
- return x_seq_out
91
-
92
- def forward(self, x):
93
- """Run model forward"""
94
- # Output has shape: [batch size, num_sites, value_dim]
95
- encodeded_seqs = self._encode_value(x)
96
-
97
- # Calculate learned averaging weights
98
- attn_avg_weights = self._calculate_attention(x)
99
-
100
- # Take weighted average across num_sites
101
- value_weighted_avg = (encodeded_seqs * attn_avg_weights.unsqueeze(-1)).sum(dim=1)
102
-
103
- # Put through final processing layers
104
- x_out = self.output_network(value_weighted_avg)
105
-
106
- return x_out
107
-
108
-
109
- class SingleAttentionNetwork(AbstractSitesEncoder):
110
- """A simple attention-based model with a single multihead attention layer
111
-
112
- For the attention layer the query is based on the target alone, the key is based on the
113
- input ID and the recent input data, the value is based on the recent input data.
114
-
115
- """
116
-
117
- def __init__(
118
- self,
119
- sequence_length: int,
120
- num_sites: int,
121
- out_features: int,
122
- kdim: int = 10,
123
- id_embed_dim: int = 10,
124
- num_heads: int = 2,
125
- n_kv_res_blocks: int = 2,
126
- kv_res_block_layers: int = 2,
127
- use_id_in_value: bool = False,
128
- target_id_dim: int = 318,
129
- target_key_to_use: str = "gsp",
130
- input_key_to_use: str = "site",
131
- num_channels: int = 1,
132
- num_sites_in_inference: int = 1,
133
- ):
134
- """A simple attention-based model with a single multihead attention layer
135
-
136
- Args:
137
- sequence_length: The time sequence length of the data.
138
- num_sites: Number of sites in the input data.
139
- out_features: Number of output features. In this network this is also the embed and
140
- value dimension in the multi-head attention layer.
141
- kdim: The dimensions used the keys.
142
- id_embed_dim: Number of dimensiosn used in the site ID embedding layer(s).
143
- num_heads: Number of parallel attention heads. Note that `out_features` will be split
144
- across `num_heads` so `out_features` must be a multiple of `num_heads`.
145
- n_kv_res_blocks: Number of residual blocks to use in the key and value encoders.
146
- kv_res_block_layers: Number of fully-connected layers used in each residual block within
147
- the key and value encoders.
148
- use_id_in_value: Whether to use a site ID embedding in network used to produce the
149
- value for the attention layer.
150
- target_id_dim: The number of unique IDs.
151
- target_key_to_use: The key to use for the target in the attention layer.
152
- input_key_to_use: The key to use for the input in the attention layer.
153
- num_channels: Number of channels in the input data. For single site generation,
154
- this will be 1, as there is not channel dimension, for Sensors,
155
- this will probably be higher than that
156
- num_sites_in_inference: Number of sites to use in inference.
157
- This is used to determine the number of sites to use in the
158
- attention layer, for a single site, 1 works, while for multiple sites
159
- (such as multiple sensors), this would be higher than that
160
-
161
- """
162
- super().__init__(sequence_length, num_sites, out_features)
163
- self.sequence_length = sequence_length
164
- self.target_id_embedding = nn.Embedding(target_id_dim, out_features)
165
- self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
166
- self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
167
- self.use_id_in_value = use_id_in_value
168
- self.target_key_to_use = target_key_to_use
169
- self.input_key_to_use = input_key_to_use
170
- self.num_channels = num_channels
171
- self.num_sites_in_inference = num_sites_in_inference
172
-
173
- if use_id_in_value:
174
- self.value_id_embedding = nn.Embedding(num_sites, id_embed_dim)
175
-
176
- self._value_encoder = nn.Sequential(
177
- ResFCNet2(
178
- in_features=sequence_length * self.num_channels
179
- + int(use_id_in_value) * id_embed_dim,
180
- out_features=out_features,
181
- fc_hidden_features=sequence_length * self.num_channels,
182
- n_res_blocks=n_kv_res_blocks,
183
- res_block_layers=kv_res_block_layers,
184
- dropout_frac=0,
185
- ),
186
- )
187
-
188
- self._key_encoder = nn.Sequential(
189
- ResFCNet2(
190
- in_features=id_embed_dim + sequence_length * self.num_channels,
191
- out_features=kdim,
192
- fc_hidden_features=id_embed_dim + sequence_length * self.num_channels,
193
- n_res_blocks=n_kv_res_blocks,
194
- res_block_layers=kv_res_block_layers,
195
- dropout_frac=0,
196
- ),
197
- )
198
-
199
- self.multihead_attn = nn.MultiheadAttention(
200
- embed_dim=out_features,
201
- kdim=kdim,
202
- vdim=out_features,
203
- num_heads=num_heads,
204
- batch_first=True,
205
- )
206
-
207
- def _encode_inputs(self, x):
208
- # Shape: [batch size, sequence length, number of sites]
209
- # Shape: [batch size, station_id, sequence length, channels]
210
- input_data = x[f"{self.input_key_to_use}"]
211
- if len(input_data.shape) == 2: # one site per sample
212
- input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
213
- if len(input_data.shape) == 4: # Has multiple channels
214
- input_data = input_data[:, :, : self.sequence_length]
215
- input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
216
- else:
217
- input_data = input_data[:, : self.sequence_length]
218
- site_seqs = input_data.float()
219
- batch_size = site_seqs.shape[0]
220
- site_seqs = site_seqs.swapaxes(1, 2) # [batch size, Site ID, sequence length]
221
- return site_seqs, batch_size
222
-
223
- def _encode_query(self, x):
224
- # Select the first one
225
- if self.target_key_to_use == "gsp":
226
- # GSP seems to have a different structure
227
- ids = x[f"{self.target_key_to_use}_id"]
228
- else:
229
- ids = x[f"{self.input_key_to_use}_id"]
230
- ids = ids.int()
231
- query = self.target_id_embedding(ids).unsqueeze(1)
232
- return query
233
-
234
- def _encode_key(self, x):
235
- site_seqs, batch_size = self._encode_inputs(x)
236
-
237
- # site ID embeddings are the same for each sample
238
- site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
239
- # Each concated (site sequence, site ID embedding) is processed with encoder
240
- x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
241
- key = self._key_encoder(x_seq_in)
242
-
243
- # Reshape to [batch size, site, kdim]
244
- key = key.unflatten(0, (batch_size, self.num_sites))
245
- return key
246
-
247
- def _encode_value(self, x):
248
- site_seqs, batch_size = self._encode_inputs(x)
249
-
250
- if self.use_id_in_value:
251
- # site ID embeddings are the same for each sample
252
- site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
253
- # Each concated (site sequence, site ID embedding) is processed with encoder
254
- x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
255
- else:
256
- # Encode each site sequence independently
257
- x_seq_in = site_seqs.flatten(0, 1)
258
- value = self._value_encoder(x_seq_in)
259
-
260
- # Reshape to [batch size, site, vdim]
261
- value = value.unflatten(0, (batch_size, self.num_sites))
262
- return value
263
-
264
- def _attention_forward(self, x, average_attn_weights=True):
265
- query = self._encode_query(x)
266
- key = self._encode_key(x)
267
- value = self._encode_value(x)
268
- attn_output, attn_weights = self.multihead_attn(
269
- query, key, value, average_attn_weights=average_attn_weights
270
- )
271
-
272
- return attn_output, attn_weights
273
-
274
- def forward(self, x):
275
- """Run model forward"""
276
- # Do slicing here to only get history
277
- attn_output, attn_output_weights = self._attention_forward(x)
278
-
279
- # Reshape from [batch_size, 1, vdim] to [batch_size, vdim]
280
- x_out = attn_output.squeeze()
281
- if len(x_out.shape) == 1:
282
- x_out = x_out.unsqueeze(0)
283
-
284
- return x_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/multimodal/unimodal_teacher.py DELETED
@@ -1,447 +0,0 @@
1
- """The default composite model architecture for PVNet"""
2
-
3
- import glob
4
- from collections import OrderedDict
5
- from typing import Any, Optional
6
-
7
- import hydra
8
- import torch
9
- import torch.nn.functional as F
10
- from pyaml_env import parse_config
11
- from torch import nn
12
-
13
- import pvnet
14
- from pvnet.models.base_model import BaseModel
15
- from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
16
- from pvnet.optimizers import AbstractOptimizer
17
-
18
-
19
class Model(BaseModel):
    """Neural network which combines information from different sources

    The network is trained via unimodal teachers [1].

    Architecture is roughly as follows:

    - Satellite data, if included, is put through an encoder which transforms it from 4D, with time,
        channel, height, and width dimensions to become a 1D feature vector.
    - NWP, if included, is put through a similar encoder.
    - PV site-level data, if included, is put through an encoder which transforms it from 2D, with
        time and system-ID dimensions, to become a 1D feature vector.
    - The satellite features*, NWP features*, PV site-level features*, GSP ID embedding*, and sun
        parameters* are concatenated into a 1D feature vector and passed through another neural
        network to combine them and produce a forecast.

    * if included
    [1] https://arxiv.org/pdf/2305.01233.pdf
    """

    name = "unimodal_teacher"

    def __init__(
        self,
        output_network: AbstractLinearNetwork,
        output_quantiles: Optional[list[float]] = None,
        include_gsp_yield_history: bool = True,
        include_sun: bool = True,
        location_id_mapping: Optional[dict[Any, int]] = None,
        embedding_dim: Optional[int] = 16,
        forecast_minutes: int = 30,
        history_minutes: int = 60,
        optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
        mode_teacher_dict: Optional[dict] = None,
        val_best: bool = True,
        cold_start: bool = True,
        enc_loss_frac: float = 0.3,
        adapt_batches: Optional[bool] = False,
    ):
        """Neural network which combines information from different sources.

        The network is trained via unimodal teachers [1].

        [1] https://arxiv.org/pdf/2305.01233.pdf

        Notes:
            In the args, where it says a module `m` is partially instantiated, it means that a
            normal pytorch module will be returned by running `mod = m(**kwargs)`. In this library,
            this partial instantiation is generally achieved using partial instantiation via hydra.
            However, the arg is still valid as long as `m(**kwargs)` returns a valid pytorch module
            - for example if `m` is a regular function.

        Args:
            output_network: A partially instantiated pytorch Module class used to combine the 1D
                features to produce the forecast.
            output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
                None the output is a single value.
            include_gsp_yield_history: Include GSP yield data.
            include_sun: Include sun azimuth and altitude data.
            location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
                not used if this is not provided.
            embedding_dim: Number of embedding dimensions to use for GSP ID
            forecast_minutes: The amount of minutes that should be forecasted.
            history_minutes: The default amount of historical minutes that are used.
            optimizer: Optimizer factory function used for network.
            mode_teacher_dict: A dictionary of paths to different model checkpoint directories,
                which will be used as the unimodal teachers. Defaults to no teachers.
            val_best: Whether to load the model which performed best on the validation set. Else the
                last checkpoint is loaded.
            cold_start: Whether to train the uni-modal encoders from scratch. Else start them with
                weights from the uni-modal teachers.
            enc_loss_frac: Fraction of total loss attributed to the teacher encoders.
            adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
                the model to use. This allows us to overprepare batches and slice from them for the
                data we need for a model run.

        Raises:
            ValueError: If `location_id_mapping` is given but `embedding_dim` is not a positive
                number, since the ID embedding could not be constructed.
        """
        # Avoid sharing one mutable default dict between all instances
        if mode_teacher_dict is None:
            mode_teacher_dict = {}

        self.include_gsp_yield_history = include_gsp_yield_history
        self.include_sun = include_sun
        self.location_id_mapping = location_id_mapping
        self.embedding_dim = embedding_dim
        self.enc_loss_frac = enc_loss_frac
        self.include_sat = False
        self.include_nwp = False
        self.include_pv = False
        self.adapt_batches = adapt_batches

        self.use_id_embedding = location_id_mapping is not None

        if self.use_id_embedding:
            if not embedding_dim:
                raise ValueError(
                    "`embedding_dim` must be a positive integer when `location_id_mapping` is set"
                )
            num_embeddings = max(location_id_mapping.values()) + 1

        # This is set but modified later based on the teachers
        self.add_image_embedding_channel = False

        super().__init__(
            history_minutes=history_minutes,
            forecast_minutes=forecast_minutes,
            optimizer=optimizer,
            output_quantiles=output_quantiles,
            target_key="gsp",
        )

        # Number of features expected by the output_network
        # Add to this as network pieces are constructed
        fusion_input_features = 0

        self.teacher_models = torch.nn.ModuleDict()
        self.mode_teacher_dict = mode_teacher_dict

        for mode, path in mode_teacher_dict.items():
            # load teacher model and freeze its weights
            self.teacher_models[mode] = self.get_unimodal_encoder(path, True, val_best=val_best)

            for param in self.teacher_models[mode].parameters():
                param.requires_grad = False

            # Recreate model as student
            mode_student_model = self.get_unimodal_encoder(
                path, load_weights=(not cold_start), val_best=val_best
            )

            if mode == "sat":
                self.include_sat = True
                self.sat_sequence_len = mode_student_model.sat_sequence_len
                self.sat_encoder = mode_student_model.sat_encoder

                if mode_student_model.add_image_embedding_channel:
                    self.sat_embed = mode_student_model.sat_embed
                    self.add_image_embedding_channel = True

                fusion_input_features += self.sat_encoder.out_features

            elif mode == "site":
                self.include_pv = True
                self.site_encoder = mode_student_model.site_encoder
                fusion_input_features += self.site_encoder.out_features

            elif mode.startswith("nwp"):
                nwp_source = mode.removeprefix("nwp/")

                # NOTE(review): the containers are created when the first NWP teacher is seen;
                # this appears to assume all teachers agree on `add_image_embedding_channel`.
                if not self.include_nwp:
                    self.include_nwp = True
                    self.nwp_encoders_dict = torch.nn.ModuleDict()

                    if mode_student_model.add_image_embedding_channel:
                        self.add_image_embedding_channel = True
                        self.nwp_embed_dict = torch.nn.ModuleDict()

                self.nwp_encoders_dict[nwp_source] = mode_student_model.nwp_encoders_dict[
                    nwp_source
                ]

                if self.add_image_embedding_channel:
                    self.nwp_embed_dict[nwp_source] = mode_student_model.nwp_embed_dict[nwp_source]

                fusion_input_features += self.nwp_encoders_dict[nwp_source].out_features

        # The ID embedding exists exactly when a location mapping was supplied. `num_embeddings`
        # is only defined in that case, so guarding on `embedding_dim` alone raised a NameError.
        if self.use_id_embedding:
            self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
            fusion_input_features += embedding_dim

        if self.include_sun:
            self.sun_fc1 = nn.Linear(
                in_features=2 * (self.forecast_len + self.history_len + 1),
                out_features=16,
            )
            fusion_input_features += 16

        if include_gsp_yield_history:
            fusion_input_features += self.history_len

        self.output_network = output_network(
            in_features=fusion_input_features,
            out_features=self.num_output_features,
        )

        self.save_hyperparameters()

    def get_unimodal_encoder(self, path, load_weights, val_best):
        """Load a model to function as a unimodal teacher

        Args:
            path: Directory holding `model_config.yaml` and the model checkpoint(s).
            load_weights: Whether to load checkpoint weights into the instantiated model.
            val_best: If True load the single `epoch*.ckpt` file in `path`, else `last.ckpt`.

        Returns:
            The model instantiated from the config, with weights loaded if requested.
        """
        model_config = parse_config(f"{path}/model_config.yaml")

        # Load the teacher model
        encoder = hydra.utils.instantiate(model_config)

        if load_weights:
            if val_best:
                # Only one epoch (best) saved per model
                files = glob.glob(f"{path}/epoch*.ckpt")
                assert len(files) == 1, f"Expected exactly one epoch*.ckpt in {path}"
                checkpoint = torch.load(files[0], map_location="cpu")
            else:
                checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu")

            encoder.load_state_dict(state_dict=checkpoint["state_dict"])
        return encoder

    def _get_location_ids(self, x):
        """Map the raw location IDs in the batch to the integer indices used for embeddings"""
        # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
        return torch.tensor(
            [self.location_id_mapping[i.item()] for i in x[f"{self._target_key}_id"]],
            device=self.device,
            dtype=torch.int64,
        )

    def teacher_forward(self, x):
        """Run the teacher models and return their encodings

        Args:
            x: The input batch.

        Returns:
            An OrderedDict mapping each teacher mode to the encoding from that teacher.
        """
        if self.use_id_embedding:
            loc_id = self._get_location_ids(x)

        modes = OrderedDict()
        for mode, teacher_model in self.teacher_models.items():
            # ******************* Satellite imagery *************************
            if mode == "sat":
                # Shape: batch_size, seq_length, channel, height, width
                sat_data = x["satellite_actual"][:, : teacher_model.sat_sequence_len]
                sat_data = torch.swapaxes(sat_data, 1, 2).float()  # switch time and channels

                # Use the teacher's own flag, matching the NWP branch below
                if teacher_model.add_image_embedding_channel:
                    sat_data = teacher_model.sat_embed(sat_data, loc_id)

                modes[mode] = teacher_model.sat_encoder(sat_data)

            # *********************** NWP Data ************************************
            if mode.startswith("nwp"):
                nwp_source = mode.removeprefix("nwp/")

                # shape: batch_size, seq_len, n_chans, height, width
                nwp_data = x["nwp"][nwp_source]["nwp"].float()
                nwp_data = torch.swapaxes(nwp_data, 1, 2)  # switch time and channels
                nwp_data = torch.clip(nwp_data, min=-50, max=50)
                if teacher_model.add_image_embedding_channel:
                    nwp_data = teacher_model.nwp_embed_dict[nwp_source](nwp_data, loc_id)

                nwp_out = teacher_model.nwp_encoders_dict[nwp_source](nwp_data)
                modes[mode] = nwp_out

            # *********************** PV Data *************************************
            # Add site-level PV yield
            if mode == "site":
                modes[mode] = teacher_model.site_encoder(x)

        return modes

    def forward(self, x, return_modes=False):
        """Run model forward

        Args:
            x: The input batch.
            return_modes: If True also return the per-mode features fed to the output network.

        Returns:
            The network output, or the tuple `(output, modes)` if `return_modes` is True.
        """
        if self.adapt_batches:
            x = self._adapt_batch(x)

        if self.use_id_embedding:
            loc_id = self._get_location_ids(x)

        modes = OrderedDict()
        # ******************* Satellite imagery *************************
        if self.include_sat:
            # Shape: batch_size, seq_length, channel, height, width
            sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
            sat_data = torch.swapaxes(sat_data, 1, 2).float()  # switch time and channels

            if self.add_image_embedding_channel:
                sat_data = self.sat_embed(sat_data, loc_id)
            modes["sat"] = self.sat_encoder(sat_data)

        # *********************** NWP Data ************************************
        if self.include_nwp:
            # Loop through potentially many NWPs
            for nwp_source in self.nwp_encoders_dict:
                # shape: batch_size, seq_len, n_chans, height, width
                nwp_data = x["nwp"][nwp_source]["nwp"].float()
                nwp_data = torch.swapaxes(nwp_data, 1, 2)  # switch time and channels
                # Some NWP variables can overflow into NaNs when normalised if they have extreme
                # tails
                nwp_data = torch.clip(nwp_data, min=-50, max=50)

                if self.add_image_embedding_channel:
                    nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, loc_id)

                nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
                modes[f"nwp/{nwp_source}"] = nwp_out

        # *********************** PV Data *************************************
        # Add site-level PV yield
        if self.include_pv:
            if self._target_key != "site":
                modes["site"] = self.site_encoder(x)
            else:
                # Target is PV, so only take the history
                pv_history = x["pv"][:, : self.history_len].float()
                modes["site"] = self.site_encoder(pv_history)

        # *********************** GSP Data ************************************
        # add gsp yield history
        if self.include_gsp_yield_history:
            gsp_history = x["gsp"][:, : self.history_len].float()
            gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
            modes["gsp"] = gsp_history

        # ********************** Embedding of GSP ID ********************
        if self.use_id_embedding:
            modes["id"] = self.embed(loc_id)

        if self.include_sun:
            # Use only new direct keys
            sun = torch.cat(
                (
                    x["solar_azimuth"],
                    x["solar_elevation"],
                ),
                dim=1,
            ).float()
            sun = self.sun_fc1(sun)
            modes["sun"] = sun

        out = self.output_network(modes)

        if self.use_quantile_regression:
            # Shape: batch_size, seq_length * num_quantiles
            out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))

        if return_modes:
            return out, modes
        else:
            return out

    def _calculate_teacher_loss(self, modes, teacher_modes):
        """Calculate the L1 loss between each teacher encoding and the matching student encoding

        Returns:
            A dict with one `enc_loss/<mode>` entry per teacher plus their sum under
            `enc_loss/total`.
        """
        enc_losses = {}
        for m, enc in teacher_modes.items():
            enc_losses[f"enc_loss/{m}"] = F.l1_loss(enc, modes[m])
        enc_losses["enc_loss/total"] = sum(enc_losses.values())
        return enc_losses

    def training_step(self, batch, batch_idx):
        """Run training step"""
        y_hat, modes = self.forward(batch, return_modes=True)
        y = batch[self._target_key][:, -self.forecast_len :, 0]

        losses = self._calculate_common_losses(y, y_hat)

        teacher_modes = self.teacher_forward(batch)
        teacher_loss = self._calculate_teacher_loss(modes, teacher_modes)
        losses.update(teacher_loss)

        if self.use_quantile_regression:
            opt_target = losses["quantile_loss"]
        else:
            opt_target = losses["MAE"]

        t_loss = teacher_loss["enc_loss/total"]

        # The scales of the two losses
        l_s = opt_target.detach()
        tl_s = max(t_loss.detach(), 1e-9)

        # The teacher loss is rescaled to the magnitude of the forecast loss before the two are
        # blended, so `enc_loss_frac` sets the mix independently of their raw scales
        losses["opt_loss"] = t_loss / tl_s * l_s * self.enc_loss_frac + opt_target * (
            1 - self.enc_loss_frac
        )

        losses = {f"{k}/train": v for k, v in losses.items()}
        self._training_accumulate_log(batch, batch_idx, losses, y_hat)

        return losses["opt_loss/train"]

    def convert_to_multimodal_model(self, config):
        """Convert the model into a multimodal model class whilst preserving weights

        Args:
            config: The config used to instantiate this model. It is copied, not modified.

        Returns:
            The tuple `(multimodal_model, config)` of the new model and the config which
            instantiates it.
        """
        config = config.copy()

        if "cold_start" in config:
            del config["cold_start"]

        config["_target_"] = "pvnet.models.multimodal.multimodal.Model"

        nwp_keys = ["nwp_encoders_dict", "nwp_history_minutes", "nwp_forecast_minutes"]

        sources = []
        for mode, path in config["mode_teacher_dict"].items():
            model_config = parse_config(f"{path}/model_config.yaml")

            if mode.startswith("nwp"):
                nwp_source = mode.removeprefix("nwp/")
                if "nwp_encoders_dict" in config:
                    for key in nwp_keys:
                        config[key][nwp_source] = model_config[key][nwp_source]
                else:
                    for key in nwp_keys:
                        config[key] = {nwp_source: model_config[key][nwp_source]}
                    config["add_image_embedding_channel"] = model_config[
                        "add_image_embedding_channel"
                    ]
                # Record NWP for every NWP source, including the first one. Otherwise a model
                # with a single NWP source never has its NWP encoder weights copied below.
                sources.append("nwp")

            elif mode == "sat":
                for key in [
                    "sat_encoder",
                    "add_image_embedding_channel",
                    "min_sat_delay_minutes",
                    "sat_history_minutes",
                ]:
                    config[key] = model_config[key]
                sources.append("sat")

            elif mode == "site":
                for key in ["site_encoder", "site_history_minutes"]:
                    config[key] = model_config[key]
                sources.append("site")

        del config["mode_teacher_dict"]

        # Instantiate the multimodal model
        multimodal_model = hydra.utils.instantiate(config)

        if "sat" in sources:
            multimodal_model.sat_encoder.load_state_dict(self.sat_encoder.state_dict())
        if "nwp" in sources:
            multimodal_model.nwp_encoders_dict.load_state_dict(self.nwp_encoders_dict.state_dict())
        if "site" in sources:
            multimodal_model.site_encoder.load_state_dict(self.site_encoder.state_dict())

        multimodal_model.output_network.load_state_dict(self.output_network.state_dict())

        # `self.embed` only exists when the ID embedding is in use
        if self.use_id_embedding:
            multimodal_model.embed.load_state_dict(self.embed.state_dict())

        if self.include_sun:
            multimodal_model.sun_fc1.load_state_dict(self.sun_fc1.state_dict())

        return multimodal_model, config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/models/utils.py DELETED
@@ -1,123 +0,0 @@
1
- """Utility functions"""
2
-
3
- import logging
4
-
5
- import numpy as np
6
- import torch
7
-
8
# Module-level logger (the original assigned it twice; once is enough)
logger = logging.getLogger(__name__)
11
-
12
-
13
class PredAccumulator:
    """Collects sub-batch predictions when using grad accumulation and a small batch size.

    Attributes:
        _y_hats (list[torch.Tensor]): The prediction tensors appended since the last flush
    """

    def __init__(self):
        """Start with no stored predictions"""
        self._y_hats = []

    def __bool__(self):
        # Truthy once at least one prediction tensor has been stored
        return bool(self._y_hats)

    def append(self, y_hat: torch.Tensor):
        """Store one sub-batch of predictions"""
        self._y_hats.append(y_hat)

    def flush(self) -> torch.Tensor:
        """Concatenate the stored predictions along dim 0, empty the store, and return them."""
        combined = torch.cat(self._y_hats, dim=0)
        self._y_hats = []
        return combined
36
-
37
-
38
class DictListAccumulator:
    """Base class providing helpers for accumulating dictionaries of lists"""

    @staticmethod
    def _dict_list_append(d1, d2):
        # Append each value of d2 onto the list stored under the same key in d1
        for key, value in d2.items():
            d1[key].append(value)

    @staticmethod
    def _dict_init_list(d):
        # Wrap every value of d in a new single-item list
        wrapped = {}
        for key, value in d.items():
            wrapped[key] = [value]
        return wrapped
49
-
50
-
51
class MetricAccumulator(DictListAccumulator):
    """Accumulator for dictionaries of metrics.

    Collects logging metrics and reports their mean, for use when grad accumulation is on and
    the batch size is small.

    Attributes:
        _metrics (Dict[str, list[float]]): The values appended so far for each metric name.
    """

    def __init__(self):
        """Start with no stored metrics"""
        self._metrics = {}

    def __bool__(self):
        # Truthy once at least one metric name has been stored
        return len(self._metrics) > 0

    def append(self, loss_dict: dict[str, float]):
        """Append a dictionary of metrics to self"""
        if self:
            self._dict_list_append(self._metrics, loss_dict)
        else:
            # First append since the last flush: create the per-metric lists
            self._metrics = self._dict_init_list(loss_dict)

    def flush(self) -> dict[str, float]:
        """Return the mean of each accumulated metric and clear the store"""
        averaged = {name: np.mean(values) for name, values in self._metrics.items()}
        self._metrics = {}
        return averaged
80
-
81
-
82
class BatchAccumulator(DictListAccumulator):
    """Accumulates batches when using grad accumulation and the batch size is small.

    Only the entries relating to `key_to_keep` are stored from each batch.

    Attributes:
        _batches (Dict[str, list[torch.Tensor]]): The stored entries of each appended batch.
    """

    def __init__(self, key_to_keep: str = "gsp"):
        """Start with no stored batches"""
        self._batches = {}
        self.key_to_keep = key_to_keep

    def __bool__(self):
        # Truthy once at least one batch entry has been stored
        return len(self._batches) > 0

    def _filter_batch_dict(self, d):
        # Keep only the target values and their ID, t0-index and timestamp entries
        prefix = self.key_to_keep
        wanted = (prefix, f"{prefix}_id", f"{prefix}_t0_idx", f"{prefix}_time_utc")
        kept = {}
        for key, value in d.items():
            if key in wanted:
                kept[key] = value
        return kept

    def append(self, batch: dict[str, list[torch.Tensor]]):
        """Append the kept entries of a batch to self"""
        kept = self._filter_batch_dict(batch)
        if self:
            self._dict_list_append(self._batches, kept)
        else:
            self._batches = self._dict_init_list(kept)

    def flush(self) -> dict[str, list[torch.Tensor]]:
        """Concatenate all accumulated batches, return, and clear self"""
        t0_key = f"{self.key_to_keep}_t0_idx"
        merged = {}
        for key, chunks in self._batches.items():
            # The t0 index is taken from the first batch rather than concatenated
            merged[key] = chunks[0] if key == t0_key else torch.cat(chunks, dim=0)
        self._batches = {}
        return merged
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/optimizers.py DELETED
@@ -1,200 +0,0 @@
1
- """Optimizer factory-function classes.
2
- """
3
-
4
- from abc import ABC, abstractmethod
5
-
6
- import torch
7
-
8
-
9
class AbstractOptimizer(ABC):
    """Abstract class for optimizer

    Optimizer classes will be used by model like:
    > OptimizerGenerator = SomeOptimizer()  # a concrete subclass of AbstractOptimizer
    > optimizer = OptimizerGenerator(model)
    The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s
    `configure_optimizers()` method.
    See :
    https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers

    """

    @abstractmethod
    def __call__(self, model):
        """Return the optimizer (and optionally scheduler) for `model`

        Args:
            model: The model whose parameters should be optimized.
        """
        pass
26
-
27
-
28
class Adam(AbstractOptimizer):
    """Factory which builds a `torch.optim.Adam` optimizer for a model"""

    def __init__(self, lr=0.0005, **kwargs):
        """Store the learning rate and any extra keyword arguments for `torch.optim.Adam`"""
        self.lr, self.kwargs = lr, kwargs

    def __call__(self, model):
        """Return an Adam optimizer over all parameters of `model`"""
        trainable = model.parameters()
        return torch.optim.Adam(trainable, lr=self.lr, **self.kwargs)
39
-
40
-
41
class AdamW(AbstractOptimizer):
    """Factory which builds a `torch.optim.AdamW` optimizer for a model"""

    def __init__(self, lr=0.0005, **kwargs):
        """Store the learning rate and any extra keyword arguments for `torch.optim.AdamW`"""
        self.lr, self.kwargs = lr, kwargs

    def __call__(self, model):
        """Return an AdamW optimizer over all parameters of `model`"""
        trainable = model.parameters()
        return torch.optim.AdamW(trainable, lr=self.lr, **self.kwargs)
52
-
53
-
54
def find_submodule_parameters(model, search_modules):
    """Finds all parameters within given submodule types

    Args:
        model: torch Module to search through
        search_modules: Tuple (or list) of submodule types to search for

    Returns:
        A list of the parameters of every (sub)module which is an instance of one of
        `search_modules`.
    """
    # isinstance() only accepts a type or a tuple of types
    if isinstance(search_modules, list):
        search_modules = tuple(search_modules)

    if isinstance(model, search_modules):
        # Always return a list so the return type does not depend on where the match occurs
        return list(model.parameters())

    params = []
    for child in model.children():
        params += find_submodule_parameters(child, search_modules)
    return params
72
-
73
-
74
def find_other_than_submodule_parameters(model, ignore_modules):
    """Finds all parameters not with given submodule types

    Args:
        model: torch Module to search through
        ignore_modules: Tuple (or list) of submodule types to ignore

    Returns:
        A list of every parameter of `model` which does not belong to a (sub)module that is an
        instance of one of `ignore_modules`.
    """
    # isinstance() only accepts a type or a tuple of types
    if isinstance(ignore_modules, list):
        ignore_modules = tuple(ignore_modules)

    if isinstance(model, ignore_modules):
        return []

    # Collect the parameters registered directly on this module. Modules such as
    # `nn.MultiheadAttention` own parameters AND have children; only recursing into the children
    # would silently drop the directly-owned parameters from the optimizer.
    params = list(model.parameters(recurse=False))
    for child in model.children():
        params += find_other_than_submodule_parameters(child, ignore_modules)
    return params
92
-
93
-
94
class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW with a reduce-on-plateau scheduler, applying no weight decay to embeddings"""

    def __init__(
        self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs
    ):
        """Store the optimizer and scheduler settings"""
        self.lr = lr
        self.weight_decay = weight_decay
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.opt_kwargs = opt_kwargs

    def __call__(self, model):
        """Return the optimizer and scheduler config for `model`"""
        embedding_types = (torch.nn.Embedding,)

        # Parameters inside embedding layers get no weight decay; all others are decayed
        embedding_params = find_submodule_parameters(model, embedding_types)
        other_params = find_other_than_submodule_parameters(model, embedding_types)

        optimizer = torch.optim.AdamW(
            [
                {"params": other_params, "weight_decay": self.weight_decay},
                {"params": embedding_params, "weight_decay": 0.0},
            ],
            lr=self.lr,
            **self.opt_kwargs,
        )

        # The monitored validation metric depends on whether quantile regression is used
        monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": monitor,
        }
        return [optimizer], [scheduler]
133
-
134
-
135
class AdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer and reduce on plateau scheduler

    `lr` may be a single number, or a mapping from parameter-name prefixes to learning rates which
    must also hold a "default" entry used for all remaining parameters.
    """

    def __init__(
        self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, **opt_kwargs
    ):
        """AdamW optimizer and reduce on plateau scheduler

        Args:
            lr: A single learning rate, or a mapping of parameter-name prefix to learning rate.
            patience: Passed to `ReduceLROnPlateau`.
            factor: Passed to `ReduceLROnPlateau`.
            threshold: Passed to `ReduceLROnPlateau`.
            step_freq: Stored on the instance; not used within this class.
            **opt_kwargs: Extra keyword arguments passed to `torch.optim.AdamW`.
        """
        self._lr = lr
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.step_freq = step_freq
        self.opt_kwargs = opt_kwargs

    def _scheduler_config(self, opt, model):
        """Build the reduce-on-plateau scheduler config for the optimizer `opt`"""
        return {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }

    def _call_multi(self, model):
        """Build the optimizer with a separate learning rate per parameter-name prefix"""
        remaining_params = dict(model.named_parameters())

        group_args = []

        for key in self._lr.keys():
            if key == "default":
                continue

            # Move every parameter whose name starts with this prefix into its own group
            submodule_params = []
            for param_name in list(remaining_params.keys()):
                if param_name.startswith(key):
                    submodule_params.append(remaining_params.pop(param_name))

            group_args.append({"params": submodule_params, "lr": self._lr[key]})

        # Everything left over uses the optimizer-level default learning rate
        group_args.append({"params": list(remaining_params.values())})

        opt = torch.optim.AdamW(
            group_args, lr=self._lr["default"] if model.lr is None else model.lr, **self.opt_kwargs
        )

        return [opt], [self._scheduler_config(opt, model)]

    def __call__(self, model):
        """Return optimizer"""
        # Any plain number is a single learning rate. Checking for `float` only would send an
        # integer learning rate down the per-prefix path, where `.keys()` fails.
        if not isinstance(self._lr, (int, float)):
            return self._call_multi(model)

        default_lr = self._lr if model.lr is None else model.lr
        opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
        return [opt], [self._scheduler_config(opt, model)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/training.py DELETED
@@ -1,183 +0,0 @@
1
- """Training"""
2
- import os
3
- import shutil
4
- from typing import Optional
5
-
6
- import hydra
7
- import torch
8
- from lightning.pytorch import (
9
- Callback,
10
- LightningDataModule,
11
- LightningModule,
12
- Trainer,
13
- seed_everything,
14
- )
15
- from lightning.pytorch.callbacks import ModelCheckpoint
16
- from lightning.pytorch.loggers import Logger
17
- from lightning.pytorch.loggers.wandb import WandbLogger
18
- from omegaconf import DictConfig, OmegaConf
19
-
20
- from pvnet import utils
21
-
22
# Module-level logger obtained from the project's logging helper
log = utils.get_logger(name=__name__)

# Use float32 as the default floating point dtype for newly created tensors
torch.set_default_dtype(torch.float32)
25
-
26
-
27
def _callbacks_to_phase(callbacks, phase):
    """Switch every phase-aware callback to ``phase``.

    Callbacks that do not expose a ``switch_phase`` attribute are left alone.
    """
    phase_aware = (cb for cb in callbacks if hasattr(cb, "switch_phase"))
    for cb in phase_aware:
        cb.switch_phase(phase)
31
-
32
-
33
def resolve_monitor_loss(output_quantiles):
    """Pick the validation metric to monitor for the configured model output.

    This adds the option to write something like:
        monitor: "${resolve_monitor_loss:${model.output_quantiles}}"

    in the early stopping and model checkpoint callback configs, so those
    configs do not have to be edited when quantile regression is switched on
    or off.

    Args:
        output_quantiles: The model's quantile setting; None means quantile
            regression is not used.

    Returns:
        "MAE/val" when ``output_quantiles`` is None, else "quantile_loss/val".
    """
    return "MAE/val" if output_quantiles is None else "quantile_loss/val"
46
-
47
-
48
# Expose the helper to configs as the ``${resolve_monitor_loss:...}`` interpolation
OmegaConf.register_new_resolver(
    name="resolve_monitor_loss",
    resolver=resolve_monitor_loss,
)
49
-
50
-
51
def _instantiate_targets(section, kind):
    """Instantiate every entry of a config section that declares a ``_target_``."""
    instances = []
    for _, conf in section.items():
        if "_target_" in conf:
            log.info(f"Instantiating {kind} <{conf._target_}>")
            instances.append(hydra.utils.instantiate(conf))
    return instances


def _find_wandb_logger(loggers):
    """Return the first ``WandbLogger`` in ``loggers``, or None when there is none."""
    for logger in loggers:
        log.info(f"{logger}")
        if isinstance(logger, WandbLogger):
            return logger
    return None


def _align_checkpoint_dir_with_wandb(config, wandb_logger, callbacks):
    """Rename the first ModelCheckpoint dir after the wandb run and store the configs there."""
    for callback in callbacks:
        log.info(f"{callback}")
        if not isinstance(callback, ModelCheckpoint):
            continue

        # Need to call the .experiment property to initialise the logger
        wandb_logger.experiment
        callback.dirpath = "/".join(callback.dirpath.split("/")[:-1] + [wandb_logger.version])

        # Also save model config here - this makes for easy model push to huggingface
        os.makedirs(callback.dirpath, exist_ok=True)
        OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml")

        # Similarly save the data config
        data_config = config.datamodule.configuration
        if data_config is None:
            # Data config can be none if using presaved batches. We go to the presaved
            # batches to get the data config
            data_config = f"{config.datamodule.sample_dir}/data_configuration.yaml"

        # Raise explicitly: an ``assert`` would be stripped when running with ``-O``
        if not os.path.isfile(data_config):
            raise FileNotFoundError(f"Data config file not found: {data_config}")
        shutil.copyfile(data_config, f"{callback.dirpath}/data_config.yaml")

        # upload configuration up to wandb
        OmegaConf.save(config, "./experiment_config.yaml")
        wandb_logger.experiment.save(f"{callback.dirpath}/data_config.yaml", callback.dirpath)
        wandb_logger.experiment.save("./experiment_config.yaml")

        # Only the first ModelCheckpoint callback is aligned
        return


def train(config: DictConfig) -> Optional[float]:
    """Contains training pipeline.

    Instantiates all PyTorch Lightning objects from config.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization, or None
            when ``optimized_metric`` is not set in the config.

    Raises:
        NotImplementedError: If any callback requests the "pretrain" phase.
        FileNotFoundError: If the data config to store next to the checkpoints
            cannot be found (only checked when a wandb logger and a
            ModelCheckpoint callback are both in use).
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if "seed" in config:
        seed_everything(config.seed, workers=True)

    # Init lightning datamodule
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    # Init lightning model
    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    # Init lightning loggers
    loggers: list[Logger] = []
    if "logger" in config:
        loggers = _instantiate_targets(config.logger, "logger")

    # Init lightning callbacks
    callbacks: list[Callback] = []
    if "callbacks" in config:
        callbacks = _instantiate_targets(config.callbacks, "callback")

    # Align the wandb id with the checkpoint path
    # - only works if wandb logger and model checkpoint used
    # - this makes it easy to push the model to huggingface
    wandb_logger = _find_wandb_logger(loggers)
    if wandb_logger is not None:
        _align_checkpoint_dir_with_wandb(config, wandb_logger, callbacks)

    should_pretrain = any(
        hasattr(c, "training_phase") and c.training_phase == "pretrain" for c in callbacks
    )

    # TODO: remove this option
    if should_pretrain:
        # Fail before the trainer is built: pre-training always raised, and the
        # statements that used to follow the raise could never run.
        # The parameter `block_nwp_and_sat` is not available in data-sampler
        # If pretraining is re-supported in the future it is likely any pre-training logic
        # should go here or perhaps in the callbacks
        raise NotImplementedError("Pre-training is not yet supported")

    trainer: Trainer = hydra.utils.instantiate(
        config.trainer,
        logger=loggers,
        _convert_="partial",
        callbacks=callbacks,
    )

    # Train the model completely
    trainer.fit(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        loggers=loggers,
    )

    # Print path to best checkpoint (there is none when checkpointing is disabled)
    checkpoint_callback = trainer.checkpoint_callback
    if checkpoint_callback is not None:
        log.info(f"Best checkpoint path:\n{checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return trainer.callback_metrics[optimized_metric]
    return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pvnet/utils.py DELETED
@@ -1,321 +0,0 @@
1
- """Utils"""
2
- import logging
3
- import warnings
4
- from collections.abc import Sequence
5
- from typing import Optional
6
-
7
- import lightning.pytorch as pl
8
- import matplotlib.pyplot as plt
9
- import pandas as pd
10
- import pylab
11
- import rich.syntax
12
- import rich.tree
13
- import xarray as xr
14
- from lightning.pytorch.loggers import Logger
15
- from lightning.pytorch.utilities import rank_zero_only
16
- from ocf_data_sampler.select.location import Location
17
- from omegaconf import DictConfig, OmegaConf
18
-
19
-
20
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Create a python logger that is friendly to multi-GPU runs.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Logging level set on the returned logger.

    Returns:
        The configured logger, with each of its emit methods wrapped in
        ``rank_zero_only``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Every emit method is wrapped with the rank zero decorator, otherwise
    # logs would get multiplied for each GPU process in a multi-GPU setup
    method_names = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in method_names:
        wrapped = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger
40
-
41
-
42
class GSPLocationLookup:
    """Callable that maps a GSP ID to its location"""

    def __init__(self, x_osgb: xr.DataArray, y_osgb: xr.DataArray):
        """Store the coordinate arrays used for the lookup

        Args:
            x_osgb: DataArray of the OSGB x-coordinate for any given GSP ID
            y_osgb: DataArray of the OSGB y-coordinate for any given GSP ID
        """
        self.x_osgb, self.y_osgb = x_osgb, y_osgb

    def __call__(self, gsp_id: int) -> Location:
        """Return the location of a single GSP.

        Args:
            gsp_id: Integer ID of the GSP

        Returns:
            A ``Location`` built from the x and y values selected for
            ``gsp_id``, carrying ``gsp_id`` as its id.
        """
        x_coord = self.x_osgb.sel(gsp_id=gsp_id).item()
        y_coord = self.y_osgb.sel(gsp_id=gsp_id).item()
        return Location(x=x_coord, y=y_coord, id=gsp_id)
67
-
68
-
69
class SiteLocationLookup:
    """Callable that maps a site ID to its location"""

    def __init__(self, long: xr.DataArray, lat: xr.DataArray):
        """Store the coordinate arrays used for the lookup

        Args:
            long: DataArray of the longitude coordinates for any given site ID
            lat: DataArray of the latitude coordinates for any given site ID
        """
        self.longitude, self.latitude = long, lat

    def __call__(self, site_id: int) -> Location:
        """Return the location of a single site.

        Args:
            site_id: Integer ID of the site

        Returns:
            A ``Location`` in the "lon_lat" coordinate system built from the
            values selected with ``pv_system_id=site_id``.
        """
        lon = self.longitude.sel(pv_system_id=site_id).item()
        lat = self.latitude.sel(pv_system_id=site_id).item()
        return Location(coordinate_system="lon_lat", x=lon, y=lat, id=site_id)
95
-
96
-
97
def extras(config: DictConfig) -> None:
    """Apply a few optional convenience settings driven by the main config.

    Controlled by main config file:
    - ``ignore_warnings``: silence python warnings
    - ``debug``: turn on ``trainer.fast_dev_run``
    - ``trainer.fast_dev_run``: force a debugger friendly configuration

    Modifies DictConfig in place and leaves it in struct mode.

    Args:
        config (DictConfig): Configuration composed by Hydra.
    """

    logger = get_logger()

    # Lift struct mode so that new keys may be added to the config
    OmegaConf.set_struct(config, False)

    # <config.ignore_warnings=True> switches python warnings off
    if config.get("ignore_warnings"):
        logger.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # <config.debug=True> implies <config.trainer.fast_dev_run=True>
    if config.get("debug"):
        logger.info("Running in debug mode! <config.debug=True>")
        config.trainer.fast_dev_run = True

    # <config.trainer.fast_dev_run=True> forces a debugger friendly configuration
    if config.trainer.get("fast_dev_run"):
        logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        # Debuggers don't like GPUs or multiprocessing
        debug_overrides = (
            ("trainer", "gpus", 0),
            ("datamodule", "pin_memory", False),
            ("datamodule", "num_workers", 0),
        )
        for section_name, option, safe_value in debug_overrides:
            section = getattr(config, section_name)
            if section.get(option):
                setattr(section, option, safe_value)

    # Put struct mode back on so no further keys can be added
    OmegaConf.set_struct(config, True)
139
-
140
-
141
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "callbacks",
        "logger",
        "seed",
    ),
    resolve: bool = True,
) -> None:
    """Render selected config sections as a Rich tree, on screen and to a file.

    The tree is printed and also written to ``config_tree.txt`` in the current
    working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """

    dim = "dim"
    tree = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    for name in fields:
        node = tree.add(name, style=dim, guide_style=dim)

        section = config.get(name)
        if isinstance(section, DictConfig):
            text = OmegaConf.to_yaml(section, resolve=resolve)
        else:
            # Non-DictConfig values (and missing fields) are shown via str()
            text = str(section)

        node.add(rich.syntax.Syntax(text, "yaml"))

    rich.print(tree)

    with open("config_tree.txt", "w") as fp:
        rich.print(tree, file=fp)
180
-
181
-
182
def empty(*args, **kwargs):
    """Accept any arguments, do nothing, and return None"""
    return None
185
-
186
-
187
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: list[pl.Callback],
    logger: list[Logger],
) -> None:
    """This method controls which parameters from Hydra config are saved by Lightning loggers.

    Additionally saves:
        - number of total, trainable and non-trainable model parameters

    Args:
        config: Configuration composed by Hydra. Its "trainer", "model" and
            "datamodule" sections are always logged; "seed" and "callbacks"
            are logged when present.
        model: Model whose parameters are counted.
        datamodule: Unused; kept for interface compatibility.
        trainer: Trainer whose loggers receive the hyperparameters.
        callbacks: Unused; kept for interface compatibility.
        logger: Unused; kept for interface compatibility.
    """

    # choose which parts of hydra config will be saved to loggers
    hparams = {key: config[key] for key in ("trainer", "model", "datamodule")}
    for optional_key in ("seed", "callbacks"):
        if optional_key in config:
            hparams[optional_key] = config[optional_key]

    # save number of model parameters, counted in a single pass
    total = trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    hparams["model/params_total"] = total
    hparams["model/params_trainable"] = trainable
    hparams["model/params_not_trainable"] = total - trainable

    # send hparams to all loggers; `trainer.logger` is only the first one and
    # is None when no logger is configured, so iterate `trainer.loggers`
    for lg in trainer.loggers:
        lg.log_hyperparams(hparams)

        # disable logging any more hyperparameters for this logger
        # this is just a trick to prevent trainer from logging hparams of model,
        # since we already did that above
        lg.log_hyperparams = empty
229
-
230
-
231
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: list[pl.Callback],
    loggers: list[Logger],
) -> None:
    """Close the wandb run if any of ``loggers`` is a wandb logger."""

    uses_wandb = any(isinstance(lg, pl.loggers.wandb.WandbLogger) for lg in loggers)
    if not uses_wandb:
        return

    # without this sweeps with wandb logger might crash!
    import wandb

    wandb.finish()
246
-
247
-
248
def plot_batch_forecasts(
    batch,
    y_hat,
    batch_idx=None,
    quantiles=None,
    key_to_plot: str = "gsp",
    timesteps_to_plot: Optional[list[int]] = None,
):
    """Plot a batch of data and the forecast from that batch

    Draws up to 16 samples on a 4x4 grid; samples beyond the 16th are not
    shown and unused axes are switched off.

    Args:
        batch: Mapping holding tensors under ``key_to_plot``,
            ``f"{key_to_plot}_id"`` and ``f"{key_to_plot}_time_utc"``
            (assumed batch-first - TODO confirm against the data loader).
        y_hat: Forecast tensor; indexed as ``y_hat[i, :, nq]`` when
            ``quantiles`` is given, else plotted as ``y_hat[i]``.
        batch_idx: Optional batch index added to the figure title.
        quantiles: Optional quantile levels, one per entry of the last axis of
            ``y_hat``. Each level is also used to pick the line colour.
        key_to_plot: Batch key of the target to plot.
        timesteps_to_plot: Optional ``[start, stop]`` pair used to slice the
            time axis of the targets, forecasts and timestamps.

    Returns:
        The matplotlib figure.
    """
    y_id_key = f"{key_to_plot}_id"
    time_utc_key = f"{key_to_plot}_time_utc"
    plotting_name = key_to_plot.upper()

    y = batch[key_to_plot].cpu().numpy()  # Select the one it is trained on
    y_hat = y_hat.cpu().numpy()
    batch_size = y.shape[0]

    # squeeze() also drops the batch axis for a batch of one; restore it so
    # the per-sample indexing below keeps working
    ids = batch[y_id_key].cpu().numpy().squeeze()
    if ids.ndim == 0:
        ids = ids.reshape(1)

    times = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[ns]")
    if times.ndim < 2:
        # a single sample or a single timestep collapsed one of the two axes
        times = times.reshape(batch_size, -1)
    times_utc = [pd.to_datetime(t) for t in times]

    # Select between the timesteps in timesteps to plot
    if timesteps_to_plot is not None:
        start, stop = timesteps_to_plot[0], timesteps_to_plot[1]
        y = y[:, start:stop]
        y_hat = y_hat[:, start:stop]
        times_utc = [t[start:stop] for t in times_utc]

    # The colormap does not depend on the sample, so look it up once
    cm = pylab.get_cmap("twilight") if quantiles is not None else None

    fig, axes = plt.subplots(4, 4, figsize=(16, 16))

    for i, ax in enumerate(axes.ravel()):
        if i >= batch_size:
            ax.axis("off")
            continue

        sample_times = times_utc[i]
        # The forecast may cover only the tail of the plotted window
        forecast_times = sample_times[-len(y_hat[i]) :]

        ax.plot(sample_times, y[i], marker=".", color="k", label=r"$y$")

        if quantiles is None:
            ax.plot(forecast_times, y_hat[i], marker=".", color="r", label=r"$\hat{y}$")
        else:
            for nq, q in enumerate(quantiles):
                ax.plot(
                    forecast_times,
                    y_hat[i, :, nq],
                    color=cm(q),
                    label=r"$\hat{y}$" + f"({q})",
                    alpha=0.7,
                )

        ax.set_title(f"ID: {ids[i]} | {sample_times[0].date()}", fontsize="small")

        # Tick every second full hour
        xticks = [t for t in sample_times if t.minute == 0][::2]
        ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90)
        ax.grid()

    axes[0, 0].legend(loc="best")

    for ax in axes[-1, :]:
        ax.set_xlabel("Time (hour of day)")

    if batch_idx is not None:
        title = f"Normed {plotting_name} output : batch_idx={batch_idx}"
    else:
        title = f"Normed {plotting_name} output"

    # Act on this figure explicitly instead of pyplot's implicit current figure
    fig.suptitle(title)
    fig.tight_layout()

    return fig