File size: 1,920 Bytes
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import StepGeneratorParams
from src.synthetic_generation.steps.step_generator import StepGenerator


class StepGeneratorWrapper(GeneratorWrapper):
    """
    Batch-producing adapter around ``StepGenerator``.

    Delegates per-series synthesis to a wrapped ``StepGenerator`` and packs
    the results, together with batch-level metadata sampled by the base
    wrapper, into a ``TimeSeriesContainer``.
    """

    def __init__(self, params: StepGeneratorParams):
        """
        Build the wrapper and the step generator it drives.

        Parameters
        ----------
        params : StepGeneratorParams
            Configuration handed both to the base ``GeneratorWrapper`` and to
            the wrapped ``StepGenerator``.
        """
        super().__init__(params)
        self.generator = StepGenerator(params)

    def generate_batch(self, batch_size: int, seed: int | None = None) -> TimeSeriesContainer:
        """
        Produce ``batch_size`` step-function series packed into one container.

        Parameters
        ----------
        batch_size : int
            How many series to synthesize.
        seed : int, optional
            When given, the wrapper's random state is reseeded first and the
            series at position ``k`` is generated with seed ``seed + k``.
            When omitted, every series is generated with seed ``None``.

        Returns
        -------
        TimeSeriesContainer
            Container whose ``values`` is ``np.array`` of the generated series
            and whose ``start``/``frequency`` come from the sampled batch
            parameters.
        """
        if seed is not None:
            self._set_random_seeds(seed)

        # Batch-level metadata; only the "start" and "frequency" entries are
        # consumed below.
        batch_meta = self._sample_parameters(batch_size)

        # NOTE(review): seeds are derived as seed + k, so batches requested
        # with neighbouring seeds (e.g. s and s + 1) reuse most of each
        # other's per-series seeds. Left unchanged to keep outputs
        # reproducible for existing callers.
        produced = [
            self.generator.generate_time_series(None if seed is None else seed + k)
            for k in range(batch_size)
        ]
        stacked = np.array(produced)

        return TimeSeriesContainer(
            values=stacked,
            start=batch_meta["start"],
            frequency=batch_meta["frequency"],
        )