# Drought-API / utils.py
# l1aF2027's picture
# Upload 9 files
# 7ccf60d verified
import numpy as np
import pandas as pd
from datetime import datetime
from scipy.interpolate import interp1d
from tqdm.auto import tqdm
def normalize(X_static, X_time, scaler_dict=None, scaler_dict_static=None):
    """
    Normalize time series and static data using the pre-fitted scalers.

    Features are scaled one at a time: feature ``i`` is passed through
    ``scaler_dict[i]`` (time series) or ``scaler_dict_static[i]`` (static)
    when that key is present, and is left unchanged otherwise. The input
    arrays are never modified; copies are returned.

    Args:
        X_static: Static data of shape (batch_size, static_dim)
        X_time: Time series data of shape (batch_size, seq_len, time_dim)
        scaler_dict: Dictionary mapping a time series feature index to a
            fitted scaler exposing ``transform``. ``None`` (the default)
            means no time series feature is scaled.
        scaler_dict_static: Dictionary mapping a static feature index to a
            fitted scaler exposing ``transform``. ``None`` (the default)
            means no static feature is scaled.

    Returns:
        X_static_norm: Normalized static data
        X_time_norm: Normalized time series data
    """
    # The signature advertises None defaults, so treat None as "no scalers"
    # instead of failing on the membership test below.
    if scaler_dict is None:
        scaler_dict = {}
    if scaler_dict_static is None:
        scaler_dict_static = {}

    # Make a copy to avoid modifying the original data.
    # NOTE: the copies keep the input dtype, so scaled values are cast back
    # to that dtype when they are written into the copies.
    X_static_norm = X_static.copy()
    X_time_norm = X_time.copy()

    # Normalize time series data
    seq_len = X_time_norm.shape[-2]
    for index in range(X_time_norm.shape[-1]):
        if index in scaler_dict:
            # The scaler is fed a single column of shape (n_samples, 1), so
            # flatten (batch, seq_len) before the call and restore it after.
            flat = X_time_norm[:, :, index].reshape(-1, 1)
            scaled = scaler_dict[index].transform(flat)
            X_time_norm[:, :, index] = scaled.reshape(-1, seq_len)

    # Normalize static data
    for index in range(X_static_norm.shape[-1]):
        if index in scaler_dict_static:
            column = X_static_norm[:, index].reshape(-1, 1)
            scaled = scaler_dict_static[index].transform(column)
            # Flatten back to (batch,) to match the column being assigned.
            X_static_norm[:, index] = scaled.reshape(-1)

    return X_static_norm, X_time_norm
def interpolate_nans(padata, pkind='linear'):
    """
    Interpolate missing values in a 1-D array.

    Every non-finite entry (NaN as well as +/-inf) is treated as missing and
    replaced by interpolating over the finite entries, using their positions
    as the x-axis. Positions before the first or after the last finite entry
    are extrapolated.

    Args:
        padata: 1-D array (or array-like) with possible NaN values
        pkind: Kind of interpolation ('linear', 'cubic', etc.), passed to
            ``scipy.interpolate.interp1d``. If there are too few finite
            samples for the requested spline order, linear interpolation is
            used instead.

    Returns:
        interpolated_data: Array with NaN values interpolated. All zeros if
            no entry is finite; a constant array if exactly one entry is.
    """
    # Accept lists and other array-likes as well as ndarrays.
    padata = np.asarray(padata)
    aindexes = np.arange(padata.shape[0])
    agood_indexes, = np.where(np.isfinite(padata))

    # Nothing to interpolate from: fall back to zeros.
    if len(agood_indexes) == 0:
        return np.zeros_like(padata)
    # A single good value cannot define a slope, so fill with that value.
    if len(agood_indexes) == 1:
        return np.full_like(padata, padata[agood_indexes[0]])

    # A spline of order k needs at least k + 1 samples; interp1d raises
    # ValueError otherwise. Degrade to linear rather than failing on short
    # or sparse series.
    if isinstance(pkind, int):
        min_points = pkind + 1
    else:
        min_points = {"quadratic": 3, "cubic": 4}.get(pkind, 2)
    if len(agood_indexes) < min_points:
        pkind = "linear"

    # Interpolate
    f = interp1d(
        agood_indexes,
        padata[agood_indexes],
        bounds_error=False,
        copy=False,
        fill_value="extrapolate",
        kind=pkind,
    )
    return f(aindexes)
def date_encode(date):
    """
    Encode a date as a point on the unit circle so that the yearly cycle is
    continuous across the year boundary.

    Args:
        date: Date to encode; either a "%Y-%m-%d" string or an object that
            provides ``timetuple()`` (e.g. a datetime)

    Returns:
        sin_day: Sine component of day of year
        cos_day: Cosine component of day of year
    """
    # Strings are parsed first; anything else is used as given.
    parsed = datetime.strptime(date, "%Y-%m-%d") if isinstance(date, str) else date

    # tm_yday runs from 1 to 366; a fixed 366-day period is used for all years.
    day_index = parsed.timetuple().tm_yday
    angle = 2 * np.pi * day_index / 366

    return np.sin(angle), np.cos(angle)