DynaMix / dynamix/forecaster.py
import torch
import torch.nn.functional as F
from .preprocessing import DataPreprocessor


class DynaMixForecaster:
    """
    Forecasting pipeline for DynaMix models with batch processing support.
    """

    def __init__(self, model):
        """
        Initialize the forecaster with a DynaMix model.

        Args:
            model: DynaMix model instance
        """
        self.model = model

    def _init_latent_state(self, initial_condition):
        """
        Initialize the latent state from the initial condition.

        Args:
            initial_condition: Initial state of shape (batch_size, N)

        Returns:
            Initial latent state z of shape (M, batch_size)
        """
        N = self.model.N
        # Project the initial condition into the M-dimensional latent space via B,
        # then overwrite the first N latent coordinates with the observation itself
        z = torch.matmul(initial_condition, self.model.B).t()  # (M, batch_size)
        z[:N, :] = initial_condition.t()
        return z
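
    # Shape sketch (hypothetical numbers, for illustration only): with N = 2
    # and M = 5, an initial condition of shape (1, 2) yields z of shape (5, 1)
    # whose first two rows equal the observation itself:
    #
    #     ic = torch.tensor([[0.3, -1.2]])
    #     z = forecaster._init_latent_state(ic)  # z.shape == (5, 1)
    #     assert torch.allclose(z[:2, 0], ic[0])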

    def _reshape_for_model(self, context, initial_x=None, device=None):
        """
        Prepare and reshape input data for the model.

        Handles tensor conversion, dimension adjustments, and reshaping when
        feature_dim > model_dim.

        Args:
            context: Context data tensor of shape (seq_length, batch_size, feature_dim)
                or (seq_length, feature_dim)
            initial_x: Optional initial condition of shape (batch_size, feature_dim)
                or (feature_dim,)
            device: Device to place tensors on

        Returns:
            Processed context, initial_x, and a metadata tuple
            (batch_size, feature_dim, was_reshaped, original_batch_size,
            original_feature_dim, original_dim) used to restore the output shape
        """
        # Get the dtype from model parameters
        model_dtype = next(self.model.parameters()).dtype

        # Convert to torch tensors and move to the target device/dtype if needed
        if not isinstance(context, torch.Tensor):
            context = torch.tensor(context, dtype=model_dtype, device=device)
        elif context.device != device or context.dtype != model_dtype:
            context = context.to(device=device, dtype=model_dtype)
        if initial_x is not None and not isinstance(initial_x, torch.Tensor):
            initial_x = torch.tensor(initial_x, dtype=model_dtype, device=device)
        elif initial_x is not None and (initial_x.device != device or initial_x.dtype != model_dtype):
            initial_x = initial_x.to(device=device, dtype=model_dtype)

        # Check data dimensions and reshape if needed
        original_dim = context.dim()
        if original_dim == 2:
            context = context.unsqueeze(1)  # (seq_length, feature_dim) -> (seq_length, 1, feature_dim)
        elif original_dim != 3:
            raise ValueError(f"Expected 2D or 3D tensor for context, got shape {context.shape} with {context.dim()} dimensions")
        if initial_x is not None and initial_x.dim() == 1:
            initial_x = initial_x.unsqueeze(0)  # (feature_dim,) -> (1, feature_dim)
        if initial_x is not None and initial_x.shape[1] != context.shape[2]:
            raise ValueError(f"Initial condition has {initial_x.shape[1]} features, but context has {context.shape[2]} features")

        # Data shape
        seq_length, batch_size, feature_dim = context.shape

        # No reshaping needed if the data already fits the model dimension
        if feature_dim <= self.model.N:
            return context, initial_x, (batch_size, feature_dim, False, None, None, original_dim)

        print(f"Warning: Input feature dimension {feature_dim} exceeds model dimension {self.model.N}. "
              f"This may lead to performance degradation. "
              f"Reshaping data to treat each feature as a separate time series.")

        # Store original dimensions for reshaping back later
        original_batch_size = batch_size
        original_feature_dim = feature_dim

        # Reshape context to (seq_length, batch_size * feature_dim, 1); the
        # flattened batch dimension is feature-major, i.e. (feature_dim, batch_size)
        transposed = context.permute(0, 2, 1)
        new_batch_size = batch_size * feature_dim
        reshaped_context = transposed.reshape(seq_length, new_batch_size, 1)

        # Similarly reshape initial_x if provided
        reshaped_initial_x = initial_x
        if initial_x is not None:
            # Reshape from (batch_size, feature_dim) to (batch_size * feature_dim, 1)
            reshaped_initial_x = initial_x.transpose(0, 1).reshape(new_batch_size, 1)

        return reshaped_context, reshaped_initial_x, (new_batch_size, 1, True, original_batch_size, original_feature_dim, original_dim)
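
    # Reshape sketch (hypothetical numbers): for a univariate model (model.N == 1),
    # a context of shape (100, 4, 3) is flattened to (100, 12, 1), i.e. 4 batch
    # elements x 3 features become 12 independent univariate series:
    #
    #     ctx = torch.randn(100, 4, 3)
    #     ctx2, _, meta = forecaster._reshape_for_model(ctx, device=ctx.device)
    #     # ctx2.shape == (100, 12, 1); meta == (12, 1, True, 4, 3, 3)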

    def _reshape_to_original(self, output, reshape_metadata):
        """
        Reshape output back to original dimensions.

        Handles both high-dimensional reshaping and 2D input restoration.

        Args:
            output: Model output of shape (T, batch_size, N)
            reshape_metadata: Metadata tuple produced by `_reshape_for_model`;
                only its (was_reshaped, original_batch_size, original_feature_dim,
                original_dim) entries are used here

        Returns:
            Output with original shape restored
        """
        _, _, was_reshaped, original_batch_size, original_feature_dim, original_dim = reshape_metadata

        # Step 1: Undo the feature-to-batch flattening if it was applied
        if was_reshaped:
            # Current shape: (T, batch_size=original_batch_size*original_feature_dim, 1)
            T = output.shape[0]
            # First reshape to (T, original_feature_dim, original_batch_size, 1)
            # by treating the batch dimension as (original_feature_dim, original_batch_size),
            # mirroring the feature-major flattening in `_reshape_for_model`
            reshaped = output.reshape(T, original_feature_dim, original_batch_size, -1)
            # Then permute to (T, original_batch_size, original_feature_dim)
            output = reshaped.permute(0, 2, 1, 3).squeeze(-1)

        # Step 2: If the input was 2D, remove the batch dimension from the output
        if original_dim == 2 and output.shape[1] == 1:
            output = output.squeeze(1)
        return output
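
    # Round-trip sketch (continuing the hypothetical example above): an output
    # of shape (horizon, 12, 1) with meta == (12, 1, True, 4, 3, 3) is restored
    # to (horizon, 4, 3); a 2D context (original_dim == 2) additionally drops
    # the singleton batch axis, giving (horizon, feature_dim).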

    @torch.no_grad()
    def forecast(self, context, horizon, preprocessing_method="pos_embedding",
                 standardize=True, fit_nonstationary=False, initial_x=None):
        """
        Efficient batched forecasting with the DynaMix model.

        This method implements a complete forecasting pipeline including:
        - Data preprocessing (Box-Cox, detrending, standardization)
        - Embedding techniques for dimensionality matching
        - DynaMix model prediction
        - Data postprocessing (inverse transformations)

        Args:
            context: Context data tensor of shape (seq_length, batch_size, feature_dim)
                or (seq_length, feature_dim)
            horizon: Forecast horizon (number of steps to predict)
            preprocessing_method: Data preprocessing method ('pos_embedding',
                'zero_embedding', 'delay_embedding', or 'delay_embedding_random')
                (default: 'pos_embedding')
            standardize: Whether to standardize the data (default: True)
            fit_nonstationary: Whether to fit a non-stationary time series (default: False)
            initial_x: Optional initial condition of shape (batch_size, feature_dim)
                or (feature_dim,)

        Returns:
            Predicted sequence of shape (horizon, batch_size, feature_dim),
            or (horizon, feature_dim) if the context was 2D
        """
        # Get model dimensions
        M = self.model.M
        N = self.model.N
        device = context.device if isinstance(context, torch.Tensor) else self.model.B.device
        model_dtype = next(self.model.parameters()).dtype

        # Apply context reshaping if needed
        context, initial_x, shape_metadata = self._reshape_for_model(context, initial_x, device)

        # Create data preprocessor; Box-Cox and detrending are only enabled
        # when fitting non-stationary series
        preprocessor = DataPreprocessor(
            standardize=standardize,
            box_cox=fit_nonstationary,
            detrending=fit_nonstationary,
            preprocessing_method=preprocessing_method
        )

        # Step 1: Apply preprocessing pipeline
        context_embedded, initial_condition = preprocessor.preprocess(context, self.model.N, initial_x)

        # Step 2: Initialize latent state
        z = self._init_latent_state(initial_condition)

        # Step 3: Forecasting loop. CNN features over the context are precomputed
        # once and reused at every step; mixed precision is enabled on CUDA only.
        Z_gen = torch.empty(horizon, M, shape_metadata[0], device=device, dtype=model_dtype)
        with torch.amp.autocast(device_type='cuda' if device.type == 'cuda' else 'cpu', enabled=device.type == 'cuda'):
            precomputed_cnn = self.model.precompute_cnn(context_embedded)
            for t in range(horizon):
                z = self.model(z, context_embedded, precomputed_cnn=precomputed_cnn)
                Z_gen[t] = z

        # Step 4: Apply observation generation by reading out the first
        # feature_dim latent coordinates
        output = Z_gen[:, :shape_metadata[1], :].permute(0, 2, 1)  # (horizon, batch_size, feature_dim)

        # Step 5: Apply inverse data transformations (e.g. standardization, ...)
        output = preprocessor.postprocess(output)

        # Step 6: Reshape back to original dimensions if needed
        output = self._reshape_to_original(output, shape_metadata)
        return output
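

# Usage sketch (illustrative; assumes `model` is an already-trained DynaMix
# instance obtained elsewhere in this package -- model construction/loading is
# not part of this file):
#
#     import torch
#     from dynamix.forecaster import DynaMixForecaster
#
#     forecaster = DynaMixForecaster(model)
#     context = torch.randn(512, 8, 3)         # (seq_length, batch, features)
#     prediction = forecaster.forecast(context, horizon=100)
#     print(prediction.shape)                  # torch.Size([100, 8, 3])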