peterdudfield committed on
Commit
3f81be8
·
1 Parent(s): 7bffb2f

Delete scripts

Browse files
scripts/backtest_sites.py DELETED
@@ -1,539 +0,0 @@
1
- """
2
- A script to run backtest for PVNet for specific sites
3
-
4
- Use:
5
-
6
- - This script uses hydra to construct the config, just like in `run.py`. So you need to make sure
7
- that the data config is set up appropriate for the model being run in this script
8
- - The PVNet model checkpoint; the time range over which to make predictions are made;
9
- the site ids to produce forecasts for and the output directory where the results
10
- near the top of the script as hard coded user variables. These should be changed.
11
-
12
- ```
13
- python scripts/backtest_sites.py
14
- ```
15
-
16
- """
17
-
18
- try:
19
- import torch.multiprocessing as mp
20
-
21
- mp.set_start_method("spawn", force=True)
22
- mp.set_sharing_strategy("file_system")
23
- except RuntimeError:
24
- pass
25
-
26
- import json
27
- import logging
28
- import os
29
- import sys
30
-
31
- import hydra
32
- import numpy as np
33
- import pandas as pd
34
- import torch
35
- import xarray as xr
36
- from huggingface_hub import hf_hub_download
37
- from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
38
- from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device
39
- from ocf_datapipes.batch import (
40
- BatchKey,
41
- NumpyBatch,
42
- stack_np_examples_into_batch,
43
- )
44
- from ocf_datapipes.config.load import load_yaml_configuration
45
- from ocf_datapipes.load.pv.pv import OpenPVFromNetCDFIterDataPipe
46
- from ocf_datapipes.training.common import create_t0_and_loc_datapipes
47
- from ocf_datapipes.training.pvnet_site import (
48
- DictDatasetIterDataPipe,
49
- _get_datapipes_dict,
50
- construct_sliced_data_pipeline,
51
- split_dataset_dict_dp,
52
- )
53
- from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
54
- from omegaconf import DictConfig
55
- from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
56
- from torch.utils.data.datapipes.iter import IterableWrapper
57
- from tqdm import tqdm
58
-
59
- from pvnet.load_model import get_model_from_checkpoints
60
- from pvnet.utils import SiteLocationLookup
61
-
62
- # ------------------------------------------------------------------
63
- # USER CONFIGURED VARIABLES TO RUN THE SCRIPT
64
-
65
- # Directory path to save results
66
- output_dir = "PLACEHOLDER"
67
-
68
- # Local directory to load the PVNet checkpoint from. By default this should pull the best performing
69
- # checkpoint on the val set
70
- model_chckpoint_dir = "PLACEHOLDER"
71
-
72
- hf_revision = None
73
- hf_token = None
74
- hf_model_id = None
75
-
76
- # Forecasts will be made for all available init times between these
77
- start_datetime = "2022-05-08 00:00"
78
- end_datetime = "2022-05-08 00:30"
79
-
80
- # ------------------------------------------------------------------
81
- # SET UP LOGGING
82
-
83
- logger = logging.getLogger(__name__)
84
- logging.basicConfig(stream=sys.stdout, level=logging.INFO)
85
-
86
- # ------------------------------------------------------------------
87
- # DERIVED VARIABLES
88
-
89
- # This will run on GPU if it exists
90
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
91
-
92
- # ------------------------------------------------------------------
93
- # GLOBAL VARIABLES
94
-
95
- # The frequency of the pv site data
96
- FREQ_MINS = 30
97
-
98
- # When sun as elevation below this, the forecast is set to zero
99
- MIN_DAY_ELEVATION = 0
100
-
101
- # Add all pv site ids here that you wish to produce forecasts for
102
- ALL_SITE_IDS = []
103
- # Need to be in ascending order
104
- ALL_SITE_IDS.sort()
105
-
106
- # ------------------------------------------------------------------
107
- # FUNCTIONS
108
-
109
-
110
@functional_datapipe("pad_forward_pv")
class PadForwardPVIterDataPipe(IterDataPipe):
    """Pad the forecast period of each PV sample out to its full length.

    Sun position is calculated from the PV time index, and for t0 values
    close to the end of the PV data the sliced forecast window can come up
    short as the data runs out. Reindexing onto the full expected time axis
    (filling missing steps with -1) keeps every sample the same shape.
    """

    def __init__(
        self,
        pv_dp: IterDataPipe,
        forecast_duration: np.timedelta64,
        history_duration: np.timedelta64,
        time_resolution_minutes: np.timedelta64,
    ):
        """Init"""

        super().__init__()
        self.pv_dp = pv_dp
        self.forecast_duration = forecast_duration
        self.history_duration = history_duration
        self.time_resolution_minutes = time_resolution_minutes

        # Samples shorter than the history window cannot be padded sensibly
        self.min_seq_length = history_duration // time_resolution_minutes

    def __iter__(self):
        """Iter"""

        for sample in self.pv_dp:
            start = sample.time_utc.data[0]
            # Exclusive end of the full (history + forecast) time axis
            end = (
                start
                + self.history_duration
                + self.forecast_duration
                + self.time_resolution_minutes
            )
            full_time_index = np.arange(start, end, self.time_resolution_minutes)

            if len(sample.time_utc.data) < self.min_seq_length:
                raise ValueError("Not enough PV data to predict")

            yield sample.reindex(time_utc=full_time_index, fill_value=-1)
153
-
154
-
155
def load_model_from_hf(model_id: str, revision: str, token: str):
    """Load a PVNet model from HuggingFace.

    Args:
        model_id: HuggingFace repo ID of the model
        revision: Git revision (branch, tag or commit) to download
        token: HuggingFace access token (may be None for public repos)

    Returns:
        The instantiated model with the downloaded weights loaded, in eval mode
    """

    # Download the weights file
    model_file = hf_hub_download(
        repo_id=model_id,
        filename=PYTORCH_WEIGHTS_NAME,
        revision=revision,
        token=token,
    )

    # Download and load the model config file
    config_file = hf_hub_download(
        repo_id=model_id,
        filename=CONFIG_NAME,
        revision=revision,
        token=token,
    )

    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Build the (un-initialised) model object from its config
    model = hydra.utils.instantiate(config)

    # Fix: map_location was hard-coded to "cuda", which crashes on CPU-only
    # machines even though the rest of the script auto-detects the device
    map_location = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(model_file, map_location=map_location)
    model.load_state_dict(state_dict)  # type: ignore
    model.eval()  # type: ignore

    return model
185
-
186
-
187
def preds_to_dataarray(preds, model, valid_times, site_ids):
    """Put numpy array of predictions into a dataarray"""

    # Quantile models emit one probability-level column per quantile, with the
    # median relabelled as the central "forecast_mw" output. Non-quantile
    # models emit a single column, so a trailing axis is added to match.
    if model.use_quantile_regression:
        output_labels = [
            f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles
        ]
        output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
    else:
        output_labels = ["forecast_mw"]
        preds = preds[..., np.newaxis]

    return xr.DataArray(
        data=preds,
        dims=["pv_system_id", "target_datetime_utc", "output_label"],
        coords=dict(
            pv_system_id=site_ids,
            target_datetime_utc=valid_times,
            output_label=output_labels,
        ),
    )
207
-
208
-
209
# TODO change this to load the PV sites data (metadata?)
def get_sites_ds(config_path: str) -> xr.Dataset:
    """Load site data from the path in the data config.

    Args:
        config_path: Path to the data configuration file

    Returns:
        xarray.Dataset of PVLive truths and capacities
    """

    config = load_yaml_configuration(config_path)
    # The datapipe yields a single dataset; pull out that first element
    site_datapipe = OpenPVFromNetCDFIterDataPipe(pv=config.input_data.pv)
    return next(iter(site_datapipe))
225
-
226
-
227
def get_available_t0_times(start_datetime, end_datetime, config_path):
    """Filter a list of t0 init-times to those for which all required input data is available.

    Args:
        start_datetime: First potential t0 time
        end_datetime: Last potential t0 time
        config_path: Path to data config file

    Returns:
        pandas.DatetimeIndex of the init-times available for required inputs
    """

    start_datetime = pd.Timestamp(start_datetime)
    end_datetime = pd.Timestamp(end_datetime)

    # Open all the input data so we can check which of the potential init
    # times actually have the inputs they require
    datapipes_dict = _get_datapipes_dict(config_path, production=False)

    # Pop out the config file
    config = datapipes_dict.pop("config")

    # We abuse `create_t0_and_loc_datapipes()` to find which of the potential
    # init-times have input data, by feeding it fake site data whose
    # timestamps are the potential init-times. Hacky, but it works for now.

    # Init-times we would like to make predictions for
    potential_init_times = pd.date_range(start_datetime, end_datetime, freq=f"{FREQ_MINS}min")

    # Buffer the potential init-times so that none are lost off the start and
    # end. Again, a hacky step.
    history_duration = pd.Timedelta(config.input_data.pv.history_minutes, "min")
    forecast_duration = pd.Timedelta(config.input_data.pv.forecast_minutes, "min")
    buffered_potential_init_times = pd.date_range(
        start_datetime - history_duration, end_datetime + forecast_duration, freq=f"{FREQ_MINS}min"
    )

    # Build the fake single-site dataset carrying the buffered timestamps.
    # Values are scaled to ~0 so they can never pass any data checks by value.
    fake_site = (
        buffered_potential_init_times.to_frame().to_xarray().rename({"index": "time_utc"})
    )
    fake_site = fake_site.rename({0: "site_pv_power_mw"})
    fake_site = fake_site.expand_dims("pv_system_id", axis=1)
    fake_site = fake_site.assign_coords(
        pv_system_id=[0],
        latitude=("pv_system_id", [0]),
        longitude=("pv_system_id", [0]),
    )
    fake_site = fake_site.site_pv_power_mw.astype(float) * 1e-18

    # Overwrite the site data which is already in the datapipes dict
    datapipes_dict["pv"] = IterableWrapper([fake_site])

    # Use create_t0_and_loc_datapipes to get datapipe of init-times
    location_pipe, t0_datapipe = create_t0_and_loc_datapipes(
        datapipes_dict,
        configuration=config,
        key_for_t0="pv",
        shuffle=False,
    )

    # Loop over the t0s AND locations (to avoid the torch datapipes buffer
    # overflow) even though the locations themselves are not used
    available_init_times = pd.to_datetime(
        [t0 for _, t0 in zip(location_pipe, t0_datapipe)]
    )

    logger.info(
        f"{len(available_init_times)} out of {len(potential_init_times)} "
        "requested init-times have required input data"
    )

    return available_init_times
295
-
296
-
297
def get_loctimes_datapipes(config_path):
    """Create location and init-time datapipes

    Args:
        config_path: Path to data config file

    Returns:
        tuple: A tuple of datapipes
            - Datapipe yielding locations
            - Datapipe yielding init-times
    """

    # Set up ID -> location query object from the site metadata
    ds_sites = get_sites_ds(config_path)
    site_id_to_loc = SiteLocationLookup(ds_sites.longitude, ds_sites.latitude)

    # Keep only the init-times for which all input data exists
    available_target_times = get_available_t0_times(
        start_datetime,
        end_datetime,
        config_path,
    )
    num_t0s = len(available_target_times)

    # Persist the init-times being predicted. This is really helpful to check
    # whilst the backtest is running, since it takes a long time — it shows
    # up-front exactly which init-times the backtest will produce.
    available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv")

    # Cycle the full list of site locations once per init-time
    location_pipe = IterableWrapper(
        [[site_id_to_loc(site_id) for site_id in ALL_SITE_IDS]]
    ).repeat(num_t0s)

    # Shard and then unbatch the locations so that each worker generates all
    # samples for all sites of a single init-time
    location_pipe = location_pipe.sharding_filter()
    location_pipe = location_pipe.unbatch(
        unbatch_level=1
    )  # might not need this part since the site datapipe is creating examples

    # Build the times datapipe so each worker receives len(ALL_SITE_IDS)
    # copies of the same datetime for its batch
    t0_datapipe = IterableWrapper(
        [[t0 for site_id in ALL_SITE_IDS] for t0 in available_target_times]
    )
    t0_datapipe = t0_datapipe.sharding_filter()
    t0_datapipe = t0_datapipe.unbatch(
        unbatch_level=1
    )  # might not need this part since the site datapipe is creating examples

    # Both pipes yield num_t0s * num_sites individual items in total
    t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_SITE_IDS))
    location_pipe = location_pipe.set_length(num_t0s * len(ALL_SITE_IDS))

    return location_pipe, t0_datapipe
352
-
353
-
354
class ModelPipe:
    """A class to conveniently make and process predictions from batches"""

    def __init__(self, model, ds_site: xr.Dataset):
        """A class to conveniently make and process predictions from batches

        Args:
            model: PVNet site level model
            ds_site: xarray dataset of pv site true values and capacities
        """
        self.model = model
        self.ds_site = ds_site

    def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
        """Run the batch through the model and compile the predictions into an xarray DataArray

        Args:
            batch: A batch of samples with inputs for each site for the same init-time

        Returns:
            xarray.Dataset of all site and national forecasts for the batch
        """
        # Index of t0 within the PV time axis of the batch
        id0 = batch[BatchKey.pv_t0_idx]

        # t0 and number of forecast steps after it
        t0 = batch[BatchKey.pv_time_utc].cpu().numpy().astype("datetime64[s]")[0, id0]
        n_valid_times = len(batch[BatchKey.pv_time_utc][0, id0 + 1 :])
        model = self.model

        # Forecast horizon timestamps, one per step after t0
        valid_times = pd.to_datetime(
            [t0 + np.timedelta64((i + 1) * FREQ_MINS, "m") for i in range(n_valid_times)]
        )

        # Effective capacities used to scale the normalised outputs
        site_capacities = self.ds_site.nominal_capacity_wp.values

        # Un-normalise the solar elevations stored in the batch; only the
        # forecast part is needed, not the history
        elevation = batch[BatchKey.pv_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN
        elevation = elevation[:, id0 + 1 :]

        # Mask marking the times when the sun is below the day threshold
        da_sundown_mask = xr.DataArray(
            data=elevation < MIN_DAY_ELEVATION,
            dims=["pv_system_id", "target_datetime_utc"],
            coords=dict(
                pv_system_id=ALL_SITE_IDS,
                target_datetime_utc=valid_times,
            ),
        )

        with torch.no_grad():
            # Run batch through model to get 0-1 predictions for all sites
            device_batch = copy_batch_to_device(batch_to_tensor(batch), device)
            y_normed_site = model(device_batch).detach().cpu().numpy()

        da_normed_site = preds_to_dataarray(y_normed_site, model, valid_times, ALL_SITE_IDS)

        # Scale normalised forecasts by capacity and clip negatives
        da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None]

        # Zero the forecast wherever the sun is down
        da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)

        # Attach the init-time as a leading dimension
        da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
            init_time_utc=np.array([t0], dtype="datetime64[ns]")
        )

        return da_abs_site
422
-
423
-
424
def get_datapipe(config_path: str) -> NumpyBatch:
    """Construct datapipe yielding batches of concurrent samples for all sites

    Args:
        config_path: Path to the data configuration file

    Returns:
        NumpyBatch: Concurrent batch of samples for each site
    """

    # Construct location and init-time datapipes
    location_pipe, t0_datapipe = get_loctimes_datapipes(config_path)

    # One batch per init-time: the t0 pipe holds one copy of each init-time
    # per site, so divide its length by the number of sites
    num_batches = len(t0_datapipe) // len(ALL_SITE_IDS)

    # Construct sample datapipes
    data_pipeline = construct_sliced_data_pipeline(
        config_path,
        location_pipe,
        t0_datapipe,
    )

    # Pad the PV forecast window so every sample has a full-length time axis
    config = load_yaml_configuration(config_path)
    data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
        forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"),
        history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"),
        time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"),
    )

    data_pipeline = DictDatasetIterDataPipe(
        {k: v for k, v in data_pipeline.items() if k != "config"},
    ).map(split_dataset_dict_dp)

    data_pipeline = data_pipeline.pvnet_site_convert_to_numpy_batch()

    # Batch so that each worker returns a batch of all locations for a single
    # init-time, and convert to tensor for the model
    data_pipeline = (
        data_pipeline.batch(len(ALL_SITE_IDS))
        .map(stack_np_examples_into_batch)
        .map(batch_to_tensor)
    )
    return data_pipeline.set_length(num_batches)
470
-
471
-
472
@hydra.main(config_path="../configs", config_name="config.yaml", version_base="1.2")
def main(config: DictConfig):
    """Runs the backtest"""

    dataloader_kwargs = dict(
        shuffle=False,
        batch_size=None,  # batching is done in the datapipe, not the loader
        sampler=None,
        batch_sampler=None,
        # Number of workers set in the config file
        num_workers=config.datamodule.num_workers,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        prefetch_factor=config.datamodule.prefetch_factor,
        persistent_workers=False,
    )

    # Set up output dir. Fix: exist_ok=True so a restarted backtest can keep
    # writing into the same directory instead of crashing on makedirs
    os.makedirs(output_dir, exist_ok=True)

    # Create concurrent batch datapipe
    # Each batch includes a sample for each of the n sites for a single init-time
    batch_pipe = get_datapipe(config.datamodule.configuration)
    num_batches = len(batch_pipe)

    # Load the site data as an xarray object
    ds_site = get_sites_ds(config.datamodule.configuration)

    # Create a dataloader for the concurrent batches and use multiprocessing
    dataloader = DataLoader(batch_pipe, **dataloader_kwargs)

    # Load the PVNet model, preferring a local checkpoint over HuggingFace
    if model_chckpoint_dir:
        model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
    elif hf_model_id:
        model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
    else:
        raise ValueError("Provide a model checkpoint or a HuggingFace model")

    model = model.eval().to(device)

    # Create object to make predictions for each input batch
    model_pipe = ModelPipe(model, ds_site)

    # Loop through the batches. Failures on individual init-times are logged
    # and skipped so one bad batch does not kill a long backtest run.
    pbar = tqdm(total=num_batches)
    for i, batch in zip(range(num_batches), dataloader):
        try:
            # Make predictions for the init-time
            ds_abs_all = model_pipe.predict_batch(batch)

            t0 = ds_abs_all.init_time_utc.values[0]

            # Save the predictions, one netCDF file per init-time
            filename = f"{output_dir}/{t0}.nc"
            ds_abs_all.to_netcdf(filename)

            pbar.update()
        except Exception:
            # Fix: was a bare print + pass, which discarded the traceback.
            # logger.exception records the full stack for later debugging.
            logger.exception("Exception at batch %s", i)

    # Close down
    pbar.close()
    del dataloader


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/backtest_uk_gsp.py DELETED
@@ -1,431 +0,0 @@
1
- """
2
- A script to run backtest for PVNet and the summation model for UK regional and national
3
-
4
- Use:
5
-
6
- - This script uses hydra to construct the config, just like in `run.py`. So you need to make sure
7
- that the data config is set up appropriate for the model being run in this script
8
- - The PVNet and summation model checkpoints; the time range over which to make predictions are made;
9
- and the output directory where the results near the top of the script as hard coded user
10
- variables. These should be changed.
11
-
12
-
13
- ```
14
- python backtest_uk_gsp.py
15
- ```
16
-
17
- """
18
-
19
# Configure torch multiprocessing before any DataLoader workers are spawned.
# "spawn" plus the "file_system" sharing strategy avoids file-descriptor
# exhaustion when workers exchange tensors. A RuntimeError means the start
# method was already fixed elsewhere; keep the existing configuration.
try:
    import torch.multiprocessing as mp

    mp.set_start_method("spawn", force=True)
    mp.set_sharing_strategy("file_system")
except RuntimeError:
    pass
26
-
27
- import logging
28
- import os
29
- import sys
30
-
31
- import hydra
32
- import numpy as np
33
- import pandas as pd
34
- import torch
35
- import xarray as xr
36
- from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device
37
- from ocf_datapipes.batch import (
38
- BatchKey,
39
- NumpyBatch,
40
- )
41
- from ocf_datapipes.config.load import load_yaml_configuration
42
- from ocf_datapipes.load import OpenGSP
43
- from ocf_datapipes.training.common import _get_datapipes_dict
44
- from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe
45
- from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
46
- from omegaconf import DictConfig
47
-
48
- # TODO: Having this script rely on pvnet_app sets up a circular dependency. The function
49
- # `preds_to_dataarray()` should probably be moved here
50
- from pvnet_app.utils import preds_to_dataarray
51
- from torch.utils.data import DataLoader
52
- from torch.utils.data.datapipes.iter import IterableWrapper
53
- from tqdm import tqdm
54
-
55
- from pvnet.load_model import get_model_from_checkpoints
56
-
57
# ------------------------------------------------------------------
# USER CONFIGURED VARIABLES
output_dir = "/mnt/disks/extra_batches/test_backtest"

# Local directory to load the PVNet checkpoint from. By default this should
# pull the best performing checkpoint on the val set
model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/q911tei5"

# Local directory to load the summation model checkpoint from. By default this
# should pull the best performing checkpoint on the val set. If set to None a
# simple sum is used instead
summation_chckpoint_dir = (
    "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/73oa4w9t"
)

# Forecasts will be made for all available init times between these
start_datetime = "2022-05-08 00:00"
end_datetime = "2022-05-08 00:30"

# ------------------------------------------------------------------
# SET UP LOGGING

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# ------------------------------------------------------------------
# DERIVED VARIABLES

# Run on GPU when one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------------------------------------------------
# GLOBAL VARIABLES

# The frequency of the GSP data
FREQ_MINS = 30

# When the sun has elevation below this, the forecast is set to zero
MIN_DAY_ELEVATION = 0

# All regional GSP IDs - not including national which is treated separately
ALL_GSP_IDS = np.arange(1, 318)

# ------------------------------------------------------------------
# FUNCTIONS
101
-
102
-
103
def get_gsp_ds(config_path: str) -> xr.Dataset:
    """Load GSP data from the path in the data config.

    Args:
        config_path: Path to the data configuration file

    Returns:
        xarray.Dataset of PVLive truths and capacities
    """

    config = load_yaml_configuration(config_path)
    # The datapipe yields a single dataset; pull out that first element
    gsp_datapipe = OpenGSP(gsp_pv_power_zarr_path=config.input_data.gsp.gsp_zarr_path)
    return next(iter(gsp_datapipe))
118
-
119
-
120
def get_available_t0_times(start_datetime, end_datetime, config_path):
    """Filter a list of t0 init-times to those for which all required input data is available.

    Args:
        start_datetime: First potential t0 time
        end_datetime: Last potential t0 time
        config_path: Path to data config file

    Returns:
        pandas.DatetimeIndex of the init-times available for required inputs
    """

    start_datetime = pd.Timestamp(start_datetime)
    end_datetime = pd.Timestamp(end_datetime)

    # Open all the input data so we can check which of the potential init
    # times actually have the inputs they require
    datapipes_dict = _get_datapipes_dict(config_path, production=False)

    # Pop out the config file
    config = datapipes_dict.pop("config")

    # We abuse `create_t0_datapipe()` to find which of the potential
    # init-times have input data, by feeding it fake GSP data whose
    # timestamps are the potential init-times. Hacky, but it works for now.

    # Init-times we would like to make predictions for
    potential_init_times = pd.date_range(start_datetime, end_datetime, freq=f"{FREQ_MINS}min")

    # Buffer the potential init-times so that none are lost off the start and
    # end. Again, a hacky step.
    history_duration = pd.Timedelta(config.input_data.gsp.history_minutes, "min")
    forecast_duration = pd.Timedelta(config.input_data.gsp.forecast_minutes, "min")
    buffered_potential_init_times = pd.date_range(
        start_datetime - history_duration, end_datetime + forecast_duration, freq=f"{FREQ_MINS}min"
    )

    # Build the fake single-GSP dataset carrying the buffered timestamps.
    # Values are scaled to ~0 so they can never pass any data checks by value.
    fake_gsp = buffered_potential_init_times.to_frame().to_xarray().rename({"index": "time_utc"})
    fake_gsp = fake_gsp.rename({0: "gsp_pv_power_mw"})
    fake_gsp = fake_gsp.expand_dims("gsp_id", axis=1)
    fake_gsp = fake_gsp.assign_coords(
        gsp_id=[0],
        x_osgb=("gsp_id", [0]),
        y_osgb=("gsp_id", [0]),
    )
    fake_gsp = fake_gsp.gsp_pv_power_mw.astype(float) * 1e-18

    # Overwrite the GSP data which is already in the datapipes dict
    datapipes_dict["gsp"] = IterableWrapper([fake_gsp])

    # Use create_t0_datapipe to get datapipe of init-times
    t0_datapipe = create_t0_datapipe(
        datapipes_dict,
        configuration=config,
        shuffle=False,
    )

    # Create a full list of available init-times
    available_init_times = pd.to_datetime([t0 for t0 in t0_datapipe])

    logger.info(
        f"{len(available_init_times)} out of {len(potential_init_times)} "
        "requested init-times have required input data"
    )

    return available_init_times
185
-
186
-
187
def get_times_datapipe(config_path):
    """Create init-time datapipe

    Args:
        config_path: Path to data config file

    Returns:
        Datapipe: A Datapipe yielding init-times
    """

    # Keep only the init-times for which all input data exists
    available_target_times = get_available_t0_times(
        start_datetime,
        end_datetime,
        config_path,
    )
    num_t0s = len(available_target_times)

    # Persist the init-times being predicted. This is really helpful to check
    # whilst the backtest is running, since it takes a long time — it shows
    # up-front exactly which init-times the backtest will produce.
    available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv")

    # Create times datapipe so each worker receives 317 copies of the same
    # datetime for its batch
    t0_datapipe = IterableWrapper(available_target_times).sharding_filter()
    t0_datapipe = t0_datapipe.set_length(num_t0s)

    return t0_datapipe
217
-
218
-
219
class ModelPipe:
    """A class to conveniently make and process predictions from batches"""

    def __init__(self, model, summation_model, ds_gsp: xr.Dataset):
        """A class to conveniently make and process predictions from batches

        Args:
            model: PVNet GSP level model
            summation_model: Summation model to make national forecast from GSP level forecasts
            ds_gsp: xarray dataset of PVLive true values and capacities
        """
        self.model = model
        self.summation_model = summation_model
        self.ds_gsp = ds_gsp

    def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
        """Run the batch through the model and compile the predictions into an xarray DataArray

        Args:
            batch: A batch of samples with inputs for each GSP for the same init-time

        Returns:
            xarray.Dataset of all GSP and national forecasts for the batch
        """

        # Index of t0 within the GSP time axis of the batch
        id0 = batch[BatchKey.gsp_t0_idx]
        t0 = batch[BatchKey.gsp_time_utc].cpu().numpy().astype("datetime64[s]")[0, id0]
        n_valid_times = len(batch[BatchKey.gsp_time_utc][0, id0 + 1 :])
        ds_gsp = self.ds_gsp
        model = self.model
        summation_model = self.summation_model

        # Forecast horizon timestamps, one per step after t0
        valid_times = pd.to_datetime(
            [t0 + np.timedelta64((i + 1) * FREQ_MINS, "m") for i in range(n_valid_times)]
        )

        # Effective capacities at t0: regional GSPs (id >= 1) and national (id 0)
        gsp_capacities = ds_gsp.effective_capacity_mwp.sel(
            time_utc=t0, gsp_id=slice(1, None)
        ).values
        national_capacity = ds_gsp.effective_capacity_mwp.sel(time_utc=t0, gsp_id=0).item()

        # Un-normalise the solar elevations stored in the batch; only the
        # forecast part is needed, not the history
        elevation = batch[BatchKey.gsp_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN
        elevation = elevation[:, id0 + 1 :]

        # Mask marking the times when the sun is below the day threshold
        da_sundown_mask = xr.DataArray(
            data=elevation < MIN_DAY_ELEVATION,
            dims=["gsp_id", "target_datetime_utc"],
            coords=dict(
                gsp_id=ALL_GSP_IDS,
                target_datetime_utc=valid_times,
            ),
        )

        with torch.no_grad():
            # Run batch through model to get 0-1 predictions for all GSPs
            device_batch = copy_batch_to_device(batch_to_tensor(batch), device)
            y_normed_gsp = model(device_batch).detach().cpu().numpy()

        da_normed_gsp = preds_to_dataarray(y_normed_gsp, model, valid_times, ALL_GSP_IDS)

        # Scale normalised forecasts by capacity and clip negatives
        da_abs_gsp = da_normed_gsp.clip(0, None) * gsp_capacities[:, None, None]

        # Zero the forecast wherever the sun is down
        da_abs_gsp = da_abs_gsp.where(~da_sundown_mask).fillna(0.0)

        if summation_model is not None:
            # Make national predictions using the summation model
            with torch.no_grad():
                # Construct sample for the summation model
                summation_inputs = {
                    "pvnet_outputs": torch.Tensor(y_normed_gsp[np.newaxis]).to(device),
                    "effective_capacity": (
                        torch.Tensor(gsp_capacities / national_capacity)
                        .to(device)
                        .unsqueeze(0)
                        .unsqueeze(-1)
                    ),
                }

                # Run batch through the summation model
                y_normed_national = (
                    summation_model(summation_inputs).detach().squeeze().cpu().numpy()
                )

            # Convert national predictions to DataArray
            da_normed_national = preds_to_dataarray(
                y_normed_national[np.newaxis], summation_model, valid_times, gsp_ids=[0]
            )

            # Scale normalised forecasts by capacity and clip negatives
            da_abs_national = da_normed_national.clip(0, None) * national_capacity

            # Apply sundown mask - All GSPs must be masked to mask national
            da_abs_national = da_abs_national.where(~da_sundown_mask.all(dim="gsp_id")).fillna(0.0)
        else:
            # If no summation model, make national predictions using simple sum
            da_abs_national = (
                da_abs_gsp.sum(dim="gsp_id")
                .expand_dims(dim="gsp_id", axis=0)
                .assign_coords(gsp_id=[0])
            )

        # Concat the regional GSP and national predictions
        da_abs_all = xr.concat([da_abs_national, da_abs_gsp], dim="gsp_id")
        ds_abs_all = da_abs_all.to_dataset(name="hindcast")

        # Attach the init-time as a leading dimension
        ds_abs_all = ds_abs_all.expand_dims(dim="init_time_utc", axis=0).assign_coords(
            init_time_utc=[t0]
        )

        return ds_abs_all
338
-
339
-
340
- def get_datapipe(config_path: str) -> NumpyBatch:
341
- """Construct datapipe yielding batches of concurrent samples for all GSPs
342
-
343
- Args:
344
- config_path: Path to the data configuration file
345
-
346
- Returns:
347
- NumpyBatch: Concurrent batch of samples for each GSP
348
- """
349
-
350
- # Construct location and init-time datapipes
351
- t0_datapipe = get_times_datapipe(config_path)
352
-
353
- # Construct sample datapipes
354
- data_pipeline = construct_sliced_data_pipeline(
355
- config_path,
356
- t0_datapipe,
357
- )
358
-
359
- # Convert to tensor for model
360
- data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe))
361
-
362
- return data_pipeline
363
-
364
-
365
- @hydra.main(config_path="../configs", config_name="config.yaml", version_base="1.2")
366
- def main(config: DictConfig):
367
- """Runs the backtest"""
368
-
369
- dataloader_kwargs = dict(
370
- shuffle=False,
371
- batch_size=None,
372
- sampler=None,
373
- batch_sampler=None,
374
- # Number of workers set in the config file
375
- num_workers=config.datamodule.num_workers,
376
- collate_fn=None,
377
- pin_memory=False,
378
- drop_last=False,
379
- timeout=0,
380
- worker_init_fn=None,
381
- prefetch_factor=config.datamodule.prefetch_factor,
382
- persistent_workers=False,
383
- )
384
-
385
- # Set up output dir
386
- os.makedirs(output_dir)
387
-
388
- # Create concurrent batch datapipe
389
- # Each batch includes a sample for each of the 317 GSPs for a single init-time
390
- batch_pipe = get_datapipe(config.datamodule.configuration)
391
- num_batches = len(batch_pipe)
392
-
393
- # Load the GSP data as an xarray object
394
- ds_gsp = get_gsp_ds(config.datamodule.configuration)
395
-
396
- # Create a dataloader for the concurrent batches and use multiprocessing
397
- dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
398
-
399
- # Load the PVNet model and summation model
400
- model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
401
- model = model.eval().to(device)
402
- if summation_chckpoint_dir is None:
403
- summation_model = None
404
- else:
405
- summation_model, *_ = get_model_from_checkpoints([summation_chckpoint_dir], val_best=True)
406
- summation_model = summation_model.eval().to(device)
407
-
408
- # Create object to make predictions for each input batch
409
- model_pipe = ModelPipe(model, summation_model, ds_gsp)
410
-
411
- # Loop through the batches
412
- pbar = tqdm(total=num_batches)
413
- for i, batch in zip(range(num_batches), dataloader):
414
- # Make predictions for the init-time
415
- ds_abs_all = model_pipe.predict_batch(batch)
416
-
417
- t0 = ds_abs_all.init_time_utc.values[0]
418
-
419
- # Save the predictions
420
- filename = f"{output_dir}/{t0}.nc"
421
- ds_abs_all.to_netcdf(filename)
422
-
423
- pbar.update()
424
-
425
- # Close down
426
- pbar.close()
427
- del dataloader
428
-
429
-
430
- if __name__ == "__main__":
431
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/checkpoint_to_huggingface.py DELETED
@@ -1,83 +0,0 @@
1
- """Command line tool to push locally save model checkpoints to huggingface
2
-
3
- use:
4
- python checkpoint_to_huggingface.py "path/to/model/checkpoints" \
5
- --huggingface-repo="openclimatefix/pvnet_uk_region" \
6
- --wandb-repo="openclimatefix/pvnet2.1" \
7
- --local-path="~/tmp/this_model" \
8
- --no-push-to-hub
9
- """
10
-
11
- import tempfile
12
-
13
- import typer
14
- import wandb
15
-
16
- from pvnet.load_model import get_model_from_checkpoints
17
-
18
- app = typer.Typer(pretty_exceptions_show_locals=False)
19
-
20
- @app.command()
21
- def push_to_huggingface(
22
- checkpoint_dir_paths: list[str],
23
- huggingface_repo: str = "openclimatefix/pvnet_uk_region", # e.g. openclimatefix/windnet_india
24
- wandb_repo: str = "openclimatefix/pvnet2.1",
25
- val_best: bool = True,
26
- wandb_ids: list[str] = [],
27
- local_path: str = None,
28
- push_to_hub: bool = True,
29
- ):
30
- """Push a local model to a huggingface model repo
31
-
32
- Args:
33
- checkpoint_dir_paths: Path(s) of the checkpoint directory(ies)
34
- huggingface_repo: Name of the HuggingFace repo to push the model to
35
- wandb_repo: Name of the wandb repo which has training logs
36
- val_best: Use best model according to val loss, else last saved model
37
- wandb_ids: The wandb ID code(s)
38
- local_path: Where to save the local copy of the model
39
- push_to_hub: Whether to push the model to the hub or just create local version.
40
- """
41
-
42
- assert push_to_hub or local_path is not None
43
-
44
- is_ensemble = len(checkpoint_dir_paths) > 1
45
-
46
- # Check if checkpoint dir name is wandb run ID
47
- if wandb_ids == []:
48
- all_wandb_ids = [run.id for run in wandb.Api().runs(path=wandb_repo)]
49
- for path in checkpoint_dir_paths:
50
- dirname = path.split("/")[-1]
51
- if dirname in all_wandb_ids:
52
- wandb_ids.append(dirname)
53
- else:
54
- wandb_ids.append(None)
55
-
56
- model, model_config, data_config = get_model_from_checkpoints(checkpoint_dir_paths, val_best)
57
-
58
- if not is_ensemble:
59
- wandb_ids = wandb_ids[0]
60
-
61
- # Push to hub
62
- if local_path is None:
63
- temp_dir = tempfile.TemporaryDirectory()
64
- model_output_dir = temp_dir.name
65
- else:
66
- model_output_dir = local_path
67
-
68
- model.save_pretrained(
69
- model_output_dir,
70
- config=model_config,
71
- data_config=data_config,
72
- wandb_repo=wandb_repo,
73
- wandb_ids=wandb_ids,
74
- push_to_hub=push_to_hub,
75
- repo_id=huggingface_repo if push_to_hub else None,
76
- )
77
-
78
- if local_path is None:
79
- temp_dir.cleanup()
80
-
81
-
82
- if __name__ == "__main__":
83
- app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/save_concurrent_samples.py DELETED
@@ -1,189 +0,0 @@
1
- """
2
- Constructs batches where each batch includes all GSPs and only a single timestamp.
3
-
4
- Currently a slightly hacky implementation due to the way the configs are done. This script will use
5
- the same config file currently set to train the model. In the datamodule config it is possible
6
- to set the batch_output_dir and number of train/val batches; they can also be overridden in the
7
- command as shown in the example below.
8
-
9
- use:
10
- ```
11
- python save_concurrent_samples.py \
12
- +datamodule.sample_output_dir="/mnt/disks/concurrent_batches/concurrent_samples_sat_pred_test" \
13
- +datamodule.num_train_samples=20 \
14
- +datamodule.num_val_samples=20
15
- ```
16
-
17
- """
18
- # Ensure this block of code runs only in the main process to avoid issues with worker processes.
19
- if __name__ == "__main__":
20
- import torch.multiprocessing as mp
21
-
22
- # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be
23
- # compatible with dask's multiprocessing.
24
- mp.set_start_method("forkserver")
25
-
26
- # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is
27
- # important because libraries like Zarr may open many files, which can exhaust the file
28
- # descriptor limit if too many workers are used.
29
- mp.set_sharing_strategy("file_system")
30
-
31
-
32
- import logging
33
- import os
34
- import shutil
35
- import sys
36
- import warnings
37
-
38
- import hydra
39
- import numpy as np
40
- import torch
41
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
42
- from omegaconf import DictConfig, OmegaConf
43
- from sqlalchemy import exc as sa_exc
44
- from torch.utils.data import DataLoader, Dataset
45
- from tqdm import tqdm
46
-
47
- from pvnet.utils import print_config
48
-
49
- # ------- filter warning and set up config -------
50
-
51
- warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
52
-
53
- logger = logging.getLogger(__name__)
54
-
55
- logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
56
-
57
- # -------------------------------------------------
58
-
59
-
60
- class SaveFuncFactory:
61
- """Factory for creating a function to save a sample to disk."""
62
-
63
- def __init__(self, save_dir: str):
64
- """Factory for creating a function to save a sample to disk."""
65
- self.save_dir = save_dir
66
-
67
- def __call__(self, sample, sample_num: int):
68
- """Save a sample to disk"""
69
- torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt")
70
-
71
-
72
- def save_samples_with_dataloader(
73
- dataset: Dataset,
74
- save_dir: str,
75
- num_samples: int,
76
- dataloader_kwargs: dict,
77
- ) -> None:
78
- """Save samples from a dataset using a dataloader."""
79
- save_func = SaveFuncFactory(save_dir)
80
-
81
- gsp_ids = np.array([loc.id for loc in dataset.locations])
82
-
83
- dataloader = DataLoader(dataset, **dataloader_kwargs)
84
-
85
- pbar = tqdm(total=num_samples)
86
- for i, sample in zip(range(num_samples), dataloader):
87
- check_sample(sample, gsp_ids)
88
- save_func(sample, i)
89
- pbar.update()
90
- pbar.close()
91
-
92
-
93
- def check_sample(sample, gsp_ids):
94
- """Check if sample is valid concurrent batch for all GSPs"""
95
- # Check all GSP IDs are included and in correct order
96
- assert (sample["gsp_id"].flatten().numpy() == gsp_ids).all()
97
- # Check all times are the same
98
- assert len(np.unique(sample["gsp_time_utc"][:, 0].numpy())) == 1
99
-
100
-
101
- @hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
102
- def main(config: DictConfig) -> None:
103
- """Constructs and saves validation and training samples."""
104
- config_dm = config.datamodule
105
-
106
- print_config(config, resolve=False)
107
-
108
- # Set up directory
109
- os.makedirs(config_dm.sample_output_dir, exist_ok=False)
110
-
111
- # Copy across configs which define the samples into the new sample directory
112
- with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f:
113
- f.write(OmegaConf.to_yaml(config_dm))
114
-
115
- shutil.copyfile(
116
- config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml"
117
- )
118
-
119
- # Define the kwargs going into the train and val dataloaders
120
- dataloader_kwargs = dict(
121
- shuffle=True,
122
- batch_size=None,
123
- sampler=None,
124
- batch_sampler=None,
125
- num_workers=config_dm.num_workers,
126
- collate_fn=None,
127
- pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial
128
- drop_last=False,
129
- timeout=0,
130
- worker_init_fn=None,
131
- prefetch_factor=config_dm.prefetch_factor,
132
- persistent_workers=False, # Not needed since we only enter the dataloader loop once
133
- )
134
-
135
- if config_dm.num_val_samples > 0:
136
- print("----- Saving val samples -----")
137
-
138
- val_output_dir = f"{config_dm.sample_output_dir}/val"
139
-
140
- # Make directory for val samples
141
- os.mkdir(val_output_dir)
142
-
143
- # Get the dataset
144
- val_dataset = PVNetUKConcurrentDataset(
145
- config_dm.configuration,
146
- start_time=config_dm.val_period[0],
147
- end_time=config_dm.val_period[1],
148
- )
149
-
150
- # Save samples
151
- save_samples_with_dataloader(
152
- dataset=val_dataset,
153
- save_dir=val_output_dir,
154
- num_samples=config_dm.num_val_samples,
155
- dataloader_kwargs=dataloader_kwargs,
156
- )
157
-
158
- del val_dataset
159
-
160
- if config_dm.num_train_samples > 0:
161
- print("----- Saving train samples -----")
162
-
163
- train_output_dir = f"{config_dm.sample_output_dir}/train"
164
-
165
- # Make directory for train samples
166
- os.mkdir(train_output_dir)
167
-
168
- # Get the dataset
169
- train_dataset = PVNetUKConcurrentDataset(
170
- config_dm.configuration,
171
- start_time=config_dm.train_period[0],
172
- end_time=config_dm.train_period[1],
173
- )
174
-
175
- # Save samples
176
- save_samples_with_dataloader(
177
- dataset=train_dataset,
178
- save_dir=train_output_dir,
179
- num_samples=config_dm.num_train_samples,
180
- dataloader_kwargs=dataloader_kwargs,
181
- )
182
-
183
- del train_dataset
184
-
185
- print("----- Saving complete -----")
186
-
187
-
188
- if __name__ == "__main__":
189
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/save_samples.py DELETED
@@ -1,218 +0,0 @@
1
- """
2
- Constructs samples and saves them to disk.
3
-
4
- Currently a slightly hacky implementation due to the way the configs are done. This script will use
5
- the same config file currently set to train the model.
6
-
7
- use:
8
- ```
9
- python save_samples.py
10
- ```
11
- if setting all values in the datamodule config file, or
12
-
13
- ```
14
- python save_samples.py \
15
- +datamodule.sample_output_dir="/mnt/disks/bigbatches/samples_v0" \
16
- +datamodule.num_train_samples=0 \
17
- +datamodule.num_val_samples=2 \
18
- datamodule.num_workers=2 \
19
- datamodule.prefetch_factor=2
20
- ```
21
- if wanting to override these values for example
22
- """
23
-
24
- # Ensure this block of code runs only in the main process to avoid issues with worker processes.
25
- if __name__ == "__main__":
26
- import torch.multiprocessing as mp
27
-
28
- # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be
29
- # compatible with dask's multiprocessing.
30
- mp.set_start_method("forkserver")
31
-
32
- # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is
33
- # important because libraries like Zarr may open many files, which can exhaust the file
34
- # descriptor limit if too many workers are used.
35
- mp.set_sharing_strategy("file_system")
36
-
37
-
38
- import logging
39
- import os
40
- import shutil
41
- import sys
42
- import warnings
43
-
44
- import dask
45
- import hydra
46
- from ocf_data_sampler.torch_datasets.datasets import PVNetUKRegionalDataset, SitesDataset
47
- from ocf_data_sampler.torch_datasets.sample.site import SiteSample
48
- from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample
49
- from omegaconf import DictConfig, OmegaConf
50
- from sqlalchemy import exc as sa_exc
51
- from torch.utils.data import DataLoader, Dataset
52
- from tqdm import tqdm
53
-
54
- from pvnet.utils import print_config
55
-
56
- dask.config.set(scheduler="threads", num_workers=4)
57
-
58
-
59
- # ------- filter warning and set up config -------
60
-
61
- warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
62
-
63
- logger = logging.getLogger(__name__)
64
-
65
- logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
66
-
67
- # -------------------------------------------------
68
-
69
-
70
- class SaveFuncFactory:
71
- """Factory for creating a function to save a sample to disk."""
72
-
73
- def __init__(self, save_dir: str, renewable: str = "pv_uk"):
74
- """Factory for creating a function to save a sample to disk."""
75
- self.save_dir = save_dir
76
- self.renewable = renewable
77
-
78
- def __call__(self, sample, sample_num: int):
79
- """Save a sample to disk"""
80
- save_path = f"{self.save_dir}/{sample_num:08}"
81
-
82
- if self.renewable == "pv_uk":
83
- sample_class = UKRegionalSample(sample)
84
- filename = f"{save_path}.pt"
85
- elif self.renewable == "site":
86
- sample_class = SiteSample(sample)
87
- filename = f"{save_path}.nc"
88
- else:
89
- raise ValueError(f"Unknown renewable: {self.renewable}")
90
- # Assign data and save
91
- sample_class._data = sample
92
- sample_class.save(filename)
93
-
94
-
95
- def get_dataset(
96
- config_path: str, start_time: str, end_time: str, renewable: str = "pv_uk"
97
- ) -> Dataset:
98
- """Get the dataset for the given renewable type."""
99
- if renewable == "pv_uk":
100
- dataset_cls = PVNetUKRegionalDataset
101
- elif renewable == "site":
102
- dataset_cls = SitesDataset
103
- else:
104
- raise ValueError(f"Unknown renewable: {renewable}")
105
-
106
- return dataset_cls(config_path, start_time=start_time, end_time=end_time)
107
-
108
-
109
- def save_samples_with_dataloader(
110
- dataset: Dataset,
111
- save_dir: str,
112
- num_samples: int,
113
- dataloader_kwargs: dict,
114
- renewable: str = "pv_uk",
115
- ) -> None:
116
- """Save samples from a dataset using a dataloader."""
117
- save_func = SaveFuncFactory(save_dir, renewable=renewable)
118
-
119
- dataloader = DataLoader(dataset, **dataloader_kwargs)
120
-
121
- pbar = tqdm(total=num_samples)
122
- for i, sample in zip(range(num_samples), dataloader):
123
- save_func(sample, i)
124
- pbar.update()
125
- pbar.close()
126
-
127
-
128
- @hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
129
- def main(config: DictConfig) -> None:
130
- """Constructs and saves validation and training samples."""
131
- config_dm = config.datamodule
132
-
133
- print_config(config, resolve=False)
134
-
135
- # Set up directory
136
- os.makedirs(config_dm.sample_output_dir, exist_ok=False)
137
-
138
- # Copy across configs which define the samples into the new sample directory
139
- with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f:
140
- f.write(OmegaConf.to_yaml(config_dm))
141
-
142
- shutil.copyfile(
143
- config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml"
144
- )
145
-
146
- # Define the kwargs going into the train and val dataloaders
147
- dataloader_kwargs = dict(
148
- shuffle=True,
149
- batch_size=None,
150
- sampler=None,
151
- batch_sampler=None,
152
- num_workers=config_dm.num_workers,
153
- collate_fn=None,
154
- pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial
155
- drop_last=False,
156
- timeout=0,
157
- worker_init_fn=None,
158
- prefetch_factor=config_dm.prefetch_factor,
159
- persistent_workers=False, # Not needed since we only enter the dataloader loop once
160
- )
161
-
162
- if config_dm.num_val_samples > 0:
163
- print("----- Saving val samples -----")
164
-
165
- val_output_dir = f"{config_dm.sample_output_dir}/val"
166
-
167
- # Make directory for val samples
168
- os.mkdir(val_output_dir)
169
-
170
- # Get the dataset
171
- val_dataset = get_dataset(
172
- config_dm.configuration,
173
- *config_dm.val_period,
174
- renewable=config.renewable,
175
- )
176
-
177
- # Save samples
178
- save_samples_with_dataloader(
179
- dataset=val_dataset,
180
- save_dir=val_output_dir,
181
- num_samples=config_dm.num_val_samples,
182
- dataloader_kwargs=dataloader_kwargs,
183
- renewable=config.renewable,
184
- )
185
-
186
- del val_dataset
187
-
188
- if config_dm.num_train_samples > 0:
189
- print("----- Saving train samples -----")
190
-
191
- train_output_dir = f"{config_dm.sample_output_dir}/train"
192
-
193
- # Make directory for train samples
194
- os.mkdir(train_output_dir)
195
-
196
- # Get the dataset
197
- train_dataset = get_dataset(
198
- config_dm.configuration,
199
- *config_dm.train_period,
200
- renewable=config.renewable,
201
- )
202
-
203
- # Save samples
204
- save_samples_with_dataloader(
205
- dataset=train_dataset,
206
- save_dir=train_output_dir,
207
- num_samples=config_dm.num_train_samples,
208
- dataloader_kwargs=dataloader_kwargs,
209
- renewable=config.renewable,
210
- )
211
-
212
- del train_dataset
213
-
214
- print("----- Saving complete -----")
215
-
216
-
217
- if __name__ == "__main__":
218
- main()