from datetime import datetime

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from tqdm.auto import tqdm


def _as_float_copy(arr):
    """Return a copy of *arr*, promoted to float64 if it has an integer/bool dtype.

    Floating dtypes (e.g. float32) are kept as-is.  Promotion matters because
    writing normalized (fractional) values back into an integer array would
    silently truncate them.
    """
    out = np.array(arr, copy=True)
    if out.dtype.kind in "biu":
        out = out.astype(np.float64)
    return out


def normalize(X_static, X_time, scaler_dict=None, scaler_dict_static=None):
    """
    Normalize time series and static data using the pre-fitted scalers

    Args:
        X_static: Static data of shape (batch_size, static_dim)
        X_time: Time series data of shape (batch_size, seq_len, time_dim)
        scaler_dict: Mapping from time-feature index to a scaler exposing
            ``transform`` on a (n, 1) column; features without an entry are
            left unchanged. ``None`` means no time features are scaled.
        scaler_dict_static: Same as ``scaler_dict`` but for static feature
            indices. ``None`` means no static features are scaled.

    Returns:
        X_static_norm: Normalized copy of the static data
        X_time_norm: Normalized copy of the time series data

    The inputs are never modified. Integer/bool inputs are returned as
    float64 so scaled values are not truncated.
    """
    # The defaults used to be passed straight to ``in``, which raised TypeError.
    if scaler_dict is None:
        scaler_dict = {}
    if scaler_dict_static is None:
        scaler_dict_static = {}

    # Work on copies so the caller's arrays are untouched
    X_static_norm = _as_float_copy(X_static)
    X_time_norm = _as_float_copy(X_time)

    # Normalize time series data: each feature is flattened over
    # (batch, seq) into one column, transformed, then restored.
    for index in range(X_time_norm.shape[-1]):
        if index in scaler_dict:
            column = X_time_norm[:, :, index]
            X_time_norm[:, :, index] = (
                scaler_dict[index]
                .transform(column.reshape(-1, 1))
                .reshape(column.shape)
            )

    # Normalize static data, one feature column at a time
    for index in range(X_static_norm.shape[-1]):
        if index in scaler_dict_static:
            column = X_static_norm[:, index]
            X_static_norm[:, index] = (
                scaler_dict_static[index]
                .transform(column.reshape(-1, 1))
                .reshape(column.shape)
            )

    return X_static_norm, X_time_norm


# Minimum spline order implied by string ``kind`` values of interp1d;
# interp1d needs at least ``order + 1`` points for these.
_SPLINE_ORDER = {"quadratic": 2, "cubic": 3}


def _spline_order(kind):
    """Return the spline order requested by an interp1d ``kind`` (1 if unknown)."""
    if isinstance(kind, (int, np.integer)) and not isinstance(kind, bool):
        return int(kind)
    return _SPLINE_ORDER.get(kind, 1)


def interpolate_nans(padata, pkind='linear'):
    """
    Interpolate missing values in an array

    Args:
        padata: 1-D array (or array-like) with possible NaN / inf values
        pkind: Kind of interpolation ('linear', 'cubic', etc.), passed to
            ``scipy.interpolate.interp1d``. If there are too few finite
            points for the requested spline order, linear is used instead.

    Returns:
        interpolated_data: Array with non-finite values interpolated; values
        outside the range of finite samples are extrapolated. If no value is
        finite, an array of zeros is returned; if exactly one is finite, that
        value fills the whole array.
    """
    padata = np.asarray(padata)
    aindexes = np.arange(padata.shape[0])
    agood_indexes, = np.where(np.isfinite(padata))

    # If all values are NaN, return zeros
    if len(agood_indexes) == 0:
        return np.zeros_like(padata)
    # If there's only one good value, fill with that value
    if len(agood_indexes) == 1:
        return np.full_like(padata, padata[agood_indexes[0]])

    # interp1d raises ValueError when it gets fewer than order + 1 points,
    # so degrade gracefully to linear interpolation in that case.
    kind = pkind
    if len(agood_indexes) <= _spline_order(pkind):
        kind = 'linear'

    f = interp1d(
        agood_indexes,
        padata[agood_indexes],  # fancy indexing already copies, so copy=False is safe
        bounds_error=False,
        copy=False,
        fill_value="extrapolate",
        kind=kind,
    )
    return f(aindexes)


def date_encode(date):
    """
    Encode date as sine and cosine components to capture cyclical patterns

    Args:
        date: Date to encode, either a "%Y-%m-%d" string or an object with
            ``timetuple()`` (e.g. ``datetime``)

    Returns:
        sin_day: Sine component of day of year
        cos_day: Cosine component of day of year

    A fixed period of 366 days is used for every year (leap or not), so the
    encoding is identical for a given day-of-year number across years.
    """
    if isinstance(date, str):
        date = datetime.strptime(date, "%Y-%m-%d")

    # Day of year is 1..366
    day_of_year = date.timetuple().tm_yday
    angle = 2 * np.pi * day_of_year / 366

    sin_day = np.sin(angle)
    cos_day = np.cos(angle)
    return sin_day, cos_day