File size: 3,220 Bytes
7ccf60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import pandas as pd
from datetime import datetime
from scipy.interpolate import interp1d
from tqdm.auto import tqdm

def _float_copy(arr):
    """Return a copy of *arr*, promoted to float64 if it holds integers or bools.

    Scaled values are written back into the copy, so an integer/bool dtype
    would silently truncate them; other dtypes are copied unchanged.
    """
    if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.bool_):
        return arr.astype(np.float64)  # astype always returns a new array
    return arr.copy()


def normalize(X_static, X_time, scaler_dict=None, scaler_dict_static=None):
    """
    Normalize time series and static data using pre-fitted scalers.

    Each scaler is called as ``scaler.transform(column)`` with ``column`` of
    shape (n, 1) and must return an array with n elements (a fitted
    scikit-learn style scaler fits this contract). Feature indices without a
    scaler are copied through unchanged. The inputs are never modified.

    Args:
        X_static: Static data of shape (batch_size, static_dim)
        X_time: Time series data of shape (batch_size, seq_len, time_dim)
        scaler_dict: Mapping from time-series feature index to scaler.
            ``None`` (the default) means no time-series feature is scaled.
        scaler_dict_static: Mapping from static feature index to scaler.
            ``None`` (the default) means no static feature is scaled.

    Returns:
        X_static_norm: Normalized copy of the static data (integer/bool
            inputs are promoted to float64 so scaled values are not truncated)
        X_time_norm: Normalized copy of the time series data (same promotion)
    """
    # The defaults are None; treat them as "no scalers" instead of crashing
    # on `index in None`.
    scaler_dict = {} if scaler_dict is None else scaler_dict
    scaler_dict_static = {} if scaler_dict_static is None else scaler_dict_static

    X_static_norm = _float_copy(X_static)
    X_time_norm = _float_copy(X_time)
    seq_len = X_time_norm.shape[-2]

    # Time series: flatten (batch, seq) into one column per feature so the
    # scaler sees every time step, then restore the (batch, seq) layout.
    for index in range(X_time_norm.shape[-1]):
        if index not in scaler_dict:
            continue
        column = X_time_norm[:, :, index].reshape(-1, 1)
        X_time_norm[:, :, index] = (
            scaler_dict[index].transform(column).reshape(-1, seq_len)
        )

    # Static data: one column per feature across the batch.
    for index in range(X_static_norm.shape[-1]):
        if index not in scaler_dict_static:
            continue
        column = X_static_norm[:, index].reshape(-1, 1)
        X_static_norm[:, index] = scaler_dict_static[index].transform(column).ravel()

    return X_static_norm, X_time_norm

def interpolate_nans(padata, pkind='linear'):
    """
    Fill non-finite entries of a 1-D array by interpolating over its index.

    Finite samples are used as known points at their integer positions. Every
    position is then evaluated with ``scipy.interpolate.interp1d``. Positions
    before the first or after the last finite sample are extrapolated.

    Args:
        padata: 1-D array that may contain NaN (or +/-inf) values
        pkind: Kind of interpolation passed to ``interp1d`` ('linear',
            'cubic', etc.). If there are too few finite samples for the
            requested spline order, linear interpolation is used instead of
            raising. 'quadratic' needs 3 samples, 'cubic' needs 4, and an
            integer order k needs k + 1.

    Returns:
        interpolated_data: Array of the same length with non-finite values
        filled. Two special cases apply:
            - all values non-finite: returns zeros;
            - exactly one finite value: returns that value everywhere.
    """
    aindexes = np.arange(padata.shape[0])
    agood_indexes, = np.where(np.isfinite(padata))
    n_good = len(agood_indexes)

    if n_good == 0:
        # Nothing to interpolate from: return zeros.
        return np.zeros_like(padata)
    if n_good == 1:
        # interp1d needs at least two points; a constant is the only sensible fill.
        return np.full_like(padata, padata[agood_indexes[0]])

    # A spline of order k needs more than k points, otherwise interp1d raises.
    # Degrade to linear interpolation rather than failing on sparse rows.
    if isinstance(pkind, (int, np.integer)):
        order = int(pkind)
    else:
        order = {"quadratic": 2, "cubic": 3}.get(pkind, 1)
    kind = "linear" if n_good <= order else pkind

    f = interp1d(
        agood_indexes,
        padata[agood_indexes],
        bounds_error=False,
        copy=False,
        fill_value="extrapolate",
        kind=kind
    )

    return f(aindexes)

def date_encode(date):
    """
    Project a date's day of year onto the unit circle.

    The day of year (1 for January 1st) is turned into an angle of
    ``2 * pi * day / 366`` and returned as its sine and cosine. This lets
    dates near the end of one year sit close to dates at the start of the next.
    The period is fixed at 366 whatever the year's actual length.

    Args:
        date: A ``"%Y-%m-%d"`` string, or any object with a ``timetuple()``
            method (e.g. ``datetime``)

    Returns:
        sin_day: Sine of the day-of-year angle
        cos_day: Cosine of the day-of-year angle
    """
    parsed = datetime.strptime(date, "%Y-%m-%d") if isinstance(date, str) else date

    angle = 2 * np.pi * parsed.timetuple().tm_yday / 366
    return np.sin(angle), np.cos(angle)