# DynaMix — dynamix/preprocessing.py
import torch
import numpy as np
from .preprocessing_utilities import (TimeSeriesProcessor, Embedding,
BoxCoxTransformer, Detrending, estimate_initial_condition)
class DataPreprocessor:
    """
    Main class for data preprocessing that orchestrates all transformations.

    Pipeline order in ``preprocess``: optional Box-Cox transform, optional
    exponential detrending, optional per-batch standardization, then an
    embedding step that lifts the data to the model dimension.
    ``postprocess`` reverses the invertible steps in the opposite order
    (un-standardize, inverse detrending, inverse Box-Cox).
    """
    def __init__(self, standardize=True, box_cox=False, detrending=False, preprocessing_method="pos_embedding"):
        """
        Initialize the data preprocessor.
        Args:
            standardize: Whether to standardize the data (per batch, over time)
            box_cox: Whether to apply Box-Cox transformation
            detrending: Whether to apply exponential detrending
            preprocessing_method: Method for embedding ('pos_embedding', 'zero_embedding',
                                  'delay_embedding', 'delay_embedding_random')
        """
        self.standardize = standardize
        self.box_cox = box_cox
        self.detrending = detrending
        self.preprocessing_method = preprocessing_method
        # Per-batch parameters recorded during preprocess() so that
        # postprocess() can invert the transformations later.
        self.box_cox_params_list = None
        self.detrending_params_list = None
        self.context_mean = None
        self.context_std = None
        self.original_context = None
        self.batch_size = None
        self.feature_dim = None
    def _apply_transformations(self, context):
        """
        Apply Box-Cox transformation and/or detrending to each batch in the context data.
        Args:
            context: Context data tensor of shape (seq_length, batch_size, N_data)
        Returns:
            Transformed context data
        """
        # Keep an untouched copy: the detrending inverse needs the raw context.
        self.original_context = context.clone()
        # Apply Box-Cox transformation for each batch independently.
        if self.box_cox:
            transformed_context = torch.zeros_like(context)
            self.box_cox_params_list = []
            for b in range(self.batch_size):
                batch_context = context[:, b, :]
                transformed, params = BoxCoxTransformer.transform(batch_context)
                transformed_context[:, b, :] = transformed
                self.box_cox_params_list.append(params)
            context = transformed_context
        # Apply detrending for each batch independently (after Box-Cox).
        if self.detrending:
            detrended_context = torch.zeros_like(context)
            self.detrending_params_list = []
            for b in range(self.batch_size):
                batch_context = context[:, b, :]
                detrended, params = Detrending.apply_detrending(batch_context)
                detrended_context[:, b, :] = detrended
                self.detrending_params_list.append(params)
            context = detrended_context
        return context
    def _apply_transformations_inverse(self, output):
        """
        Apply inverse Box-Cox and detrending transformations.
        Args:
            output: Model output of shape (T, batch_size, N)
        Returns:
            Output with transformations reversed (modified in place)
        """
        # Inverse order of the forward pass: undo detrending first, then Box-Cox.
        if self.detrending and self.detrending_params_list is not None:
            for b in range(self.batch_size):
                batch_output = output[:, b, :]
                batch_context = self.original_context[:, b, :]
                batch_output = Detrending.apply_detrending_inverse(batch_context, batch_output, self.detrending_params_list[b])
                output[:, b, :] = batch_output
        if self.box_cox and self.box_cox_params_list is not None:
            for b in range(self.batch_size):
                batch_output = output[:, b, :]
                batch_output = BoxCoxTransformer.inverse_transform(batch_output, self.box_cox_params_list[b])
                output[:, b, :] = batch_output
        return output
    def _standardize_data(self, context):
        """
        Standardize each batch in the context data (z-score over the time axis).
        Args:
            context: Context data tensor of shape (seq_length, batch_size, N_data)
        Returns:
            Standardized context data (unchanged if standardization is disabled)
        """
        if not self.standardize:
            return context
        # Calculate mean and std across the time dimension, per batch and feature.
        self.context_mean = torch.mean(context, dim=0)  # (batch_size, N_data)
        self.context_std = torch.std(context, dim=0)    # (batch_size, N_data)
        self.context_std = torch.clamp(self.context_std, min=1e-6)  # Avoid division by zero
        # Standardize using broadcasting over the time dimension.
        context = (context - self.context_mean.unsqueeze(0)) / self.context_std.unsqueeze(0)
        return context
    def _unstandardize_data(self, output):
        """
        Undo standardization by applying the inverse transformation.
        Args:
            output: Model output of shape (T, batch_size, N)
        Returns:
            Output with standardization reversed
        """
        if self.standardize and self.context_mean is not None and self.context_std is not None:
            return output * self.context_std.unsqueeze(0) + self.context_mean.unsqueeze(0)
        return output
    def _apply_embedding(self, context, model_dim):
        """
        Apply data preprocessing to each batch to reach model dimension.
        Args:
            context: Context data tensor of shape (seq_length, batch_size, N_data)
            model_dim: Target model dimension
        Returns:
            Preprocessed context data tensor of shape (min_seq_len, batch_size, model_dim)
        """
        context_embedded_batch = []
        for b in range(self.batch_size):
            batch_context = context[:, b, :]
            batch_embedded = Embedding.apply_embedding(batch_context, model_dim, self.preprocessing_method)
            context_embedded_batch.append(batch_embedded)
        # Embeddings may shorten sequences (e.g. delay embedding); keep the
        # trailing min_seq_len steps of each batch so they can be stacked.
        seq_lengths = [emb.shape[0] for emb in context_embedded_batch]
        min_seq_len = min(seq_lengths)
        context_embedded_batch = [emb[-min_seq_len:] for emb in context_embedded_batch]
        # Stack along batch dimension
        return torch.stack(context_embedded_batch, dim=1)
    def _prepare_initial_condition(self, context_embedded, initial_x, model_dim):
        """
        Prepare initial condition for forecasting.
        Args:
            context_embedded: Preprocessed context data
            initial_x: Optional initial condition of shape (batch_size, N_data)
            model_dim: Model dimension
        Returns:
            Initial condition for forecasting of shape (batch_size, model_dim)
        Raises:
            ValueError: If initial condition is provided with Box-Cox or detrending enabled
        """
        if initial_x is None:
            # Use last context value for each batch
            return context_embedded[-1]
        # Raise error if initial condition is provided with Box-Cox or detrending enabled
        if (self.box_cox or self.detrending):
            raise ValueError(
                "Using initial conditions with Box-Cox or detrending is not supported. "
                "Either disable Box-Cox and detrending or do not provide an initial condition."
            )
        # Process initial conditions for each batch
        initial_x_processed = torch.zeros(self.batch_size, model_dim,
                                          device=context_embedded.device,
                                          dtype=context_embedded.dtype)
        for b in range(self.batch_size):
            batch_initial = initial_x[b]
            # Apply standardization if enabled. Reuse the clamped std from the
            # forward pass so the initial condition is on exactly the same
            # scale as the standardized context (the previous `std + 1e-8`
            # used a different denominator for near-constant features).
            if self.standardize and self.context_mean is not None and self.context_std is not None:
                batch_initial = (batch_initial - self.context_mean[b]) / self.context_std[b]
            # If dimensions are smaller than model_dim, estimate full initial condition
            if initial_x.shape[1] < model_dim:
                # Find matching state in context_embedded
                batch_initial = estimate_initial_condition(
                    batch_initial,
                    context_embedded[:, b, :],
                )
            initial_x_processed[b] = batch_initial
        return initial_x_processed
    def preprocess(self, context, model_dim, initial_x=None):
        """
        Apply the complete preprocessing pipeline to the input data.
        Args:
            context: Context data tensor of shape (seq_length, batch_size, N_data) or (seq_length, N_data)
            model_dim: Target model dimension
            initial_x: Optional initial condition of shape (batch_size, N_data) or (N_data,)
        Returns:
            Preprocessed context data and initial condition
        """
        # Accept un-batched input as documented: promote 2D context (and 1D
        # initial condition) to a singleton batch dimension. The previous
        # implementation raised an IndexError on 2D input.
        if context.dim() == 2:
            context = context.unsqueeze(1)
        if initial_x is not None and initial_x.dim() == 1:
            initial_x = initial_x.unsqueeze(0)
        # Store dimensions
        self.batch_size = context.shape[1]
        self.feature_dim = context.shape[2]
        # Apply transformations (Box-Cox, detrending)
        context = self._apply_transformations(context)
        # Standardize data if requested
        context = self._standardize_data(context)
        # Apply embedding to reach model dimension
        context_embedded = self._apply_embedding(context, model_dim)
        # Prepare initial condition for forecasting
        initial_condition = self._prepare_initial_condition(context_embedded, initial_x, model_dim)
        return context_embedded, initial_condition
    def postprocess(self, output):
        """
        Apply inverse transformations to restore original data scaling.
        Args:
            output: Model output of shape (T, batch_size, N)
        Returns:
            Output with inverse transformations applied
        """
        # Undo standardization
        output = self._unstandardize_data(output)
        # Apply inverse transformations (Box-Cox, detrending)
        output = self._apply_transformations_inverse(output)
        return output