File size: 1,920 Bytes
c4b87d2 0a58567 c4b87d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import numpy as np
from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import StepGeneratorParams
from src.synthetic_generation.steps.step_generator import StepGenerator
class StepGeneratorWrapper(GeneratorWrapper):
    """
    Batch front end for ``StepGenerator``.

    Delegates creation of each individual step-function series to an owned
    ``StepGenerator`` and packs the stacked results, together with sampled
    batch metadata, into a ``TimeSeriesContainer``.
    """

    def __init__(self, params: StepGeneratorParams):
        """
        Build the wrapper and the step generator it delegates to.

        Parameters
        ----------
        params : StepGeneratorParams
            Configuration handed both to the ``GeneratorWrapper`` base class
            and to the ``StepGenerator`` instance stored on ``self.generator``.
        """
        super().__init__(params)
        self.generator = StepGenerator(params)

    def generate_batch(self, batch_size: int, seed: int | None = None) -> TimeSeriesContainer:
        """
        Produce ``batch_size`` step-function series as one container.

        Parameters
        ----------
        batch_size : int
            How many series to generate.
        seed : int, optional
            Base seed. When given, it is first passed to
            ``self._set_random_seeds`` and series ``k`` of the batch is then
            generated with seed ``seed + k``. When omitted, every series is
            generated with a ``None`` seed.

        Returns
        -------
        TimeSeriesContainer
            ``values`` is ``np.array`` applied to the list of generated
            series. ``start`` and ``frequency`` are taken from the mapping
            returned by ``self._sample_parameters``.
        """
        if seed is not None:
            self._set_random_seeds(seed)

        # Batch metadata is sampled before any series is generated, so the
        # seeded RNG state is consumed in the same order on every call.
        batch_meta = self._sample_parameters(batch_size)

        def _offset_seed(offset: int) -> int | None:
            # Give each series in the batch its own deterministic seed.
            return None if seed is None else seed + offset

        generated = [
            self.generator.generate_time_series(_offset_seed(offset))
            for offset in range(batch_size)
        ]

        stacked = np.array(generated)
        return TimeSeriesContainer(
            values=stacked,
            start=batch_meta["start"],
            frequency=batch_meta["frequency"],
        )
|