Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| from .preprocessing import DataPreprocessor | |
class DynaMixForecaster:
    """
    Forecasting pipeline for DynaMix models with batch processing support.

    The wrapped model is expected to expose what this class reads from it:
    the attributes ``N``, ``M`` and ``B``, plus ``parameters()``,
    ``precompute_cnn()`` and ``__call__``.
    """

    def __init__(self, model):
        """
        Initialize the forecaster with a DynaMix model.

        Args:
            model: DynaMix model instance
        """
        self.model = model

    @staticmethod
    def _to_model_tensor(data, dtype, device):
        """Return ``data`` as a tensor with the given dtype on the given device."""
        if isinstance(data, torch.Tensor):
            # ``Tensor.to`` returns the tensor itself when nothing has to change.
            return data.to(device=device, dtype=dtype)
        return torch.tensor(data, dtype=dtype, device=device)

    def _init_latent_state(self, initial_condition):
        """
        Initialize the latent state from the initial condition.

        The state is ``initial_condition @ model.B`` transposed, with its first
        ``model.N`` rows overwritten by the initial condition itself.

        Args:
            initial_condition: Initial state of shape (batch_size, N)

        Returns:
            Initial latent state z, shape (M, batch_size) assuming ``model.B``
            has shape (N, M)
        """
        N = self.model.N
        z = torch.matmul(initial_condition, self.model.B).t()  # (M, batch_size)
        # The first N latent units are set to the initial condition directly.
        z[:N, :] = initial_condition.t()
        return z

    def _reshape_for_model(self, context, initial_x=None, device=None):
        """
        Prepare and reshape input data for the model.

        Handles tensor conversion, dimension adjustments, and reshaping when
        feature_dim > model_dim (each feature is then treated as its own
        univariate series in the batch dimension).

        Args:
            context: Context data of shape (seq_length, batch_size, feature_dim)
                or (seq_length, feature_dim)
            initial_x: Optional initial condition of shape
                (batch_size, feature_dim) or (feature_dim,)
            device: Device to place tensors on

        Returns:
            Tuple ``(context, initial_x, metadata)`` where ``metadata`` is
            ``(batch_size, feature_dim, was_reshaped, original_batch_size,
            original_feature_dim, original_dim)``. The first two entries
            describe the tensors handed to the model; the original sizes are
            ``None`` when no reshaping took place.

        Raises:
            ValueError: If ``context`` is not 2D/3D, or ``initial_x`` is not
                1D/2D, or the feature counts of ``initial_x`` and ``context``
                differ.
        """
        model_dtype = next(self.model.parameters()).dtype

        context = self._to_model_tensor(context, model_dtype, device)
        if initial_x is not None:
            initial_x = self._to_model_tensor(initial_x, model_dtype, device)

        # Bring the context into (seq_length, batch_size, feature_dim) layout.
        original_dim = context.dim()
        if original_dim == 2:
            context = context.unsqueeze(1)  # (seq_length, feature_dim) -> (seq_length, 1, feature_dim)
        elif original_dim != 3:
            raise ValueError(f"Expected 2D or 3D tensor for context, got shape {context.shape} with {context.dim()} dimensions")

        # Validate every supplied initial condition, not only 1D ones.
        if initial_x is not None:
            if initial_x.dim() == 1:
                initial_x = initial_x.unsqueeze(0)  # (feature_dim,) -> (1, feature_dim)
            if initial_x.dim() != 2:
                raise ValueError(f"Expected 1D or 2D tensor for initial_x, got shape {initial_x.shape} with {initial_x.dim()} dimensions")
            if initial_x.shape[1] != context.shape[2]:
                raise ValueError(f"Initial condition has {initial_x.shape[1]} features, but context has {context.shape[2]} features")

        seq_length, batch_size, feature_dim = context.shape

        # Nothing else to do when the data fits into the model's observation space.
        if feature_dim <= self.model.N:
            return context, initial_x, (batch_size, feature_dim, False, None, None, original_dim)

        print(f"Warning: Input feature dimension {feature_dim} exceeds model dimension {self.model.N}. "
              f"This may lead to performance degradation. "
              f"Reshaping data to treat each feature as separate time series.")

        # Fold the features into the batch axis, feature-major: the new batch
        # index is f * batch_size + b. _reshape_to_original relies on this order.
        new_batch_size = batch_size * feature_dim
        reshaped_context = context.permute(0, 2, 1).reshape(seq_length, new_batch_size, 1)

        reshaped_initial_x = initial_x
        if initial_x is not None:
            # (batch_size, feature_dim) -> (batch_size * feature_dim, 1), same ordering
            reshaped_initial_x = initial_x.transpose(0, 1).reshape(new_batch_size, 1)

        return reshaped_context, reshaped_initial_x, (new_batch_size, 1, True, batch_size, feature_dim, original_dim)

    def _reshape_to_original(self, output, reshape_metadata):
        """
        Reshape output back to original dimensions.

        Handles both high-dimensional reshaping and 2D input restoration.

        Args:
            output: Model output of shape (T, batch_size, N)
            reshape_metadata: The 6-tuple produced by ``_reshape_for_model``:
                (batch_size, feature_dim, was_reshaped, original_batch_size,
                original_feature_dim, original_dim)

        Returns:
            Output with original shape restored
        """
        _, _, was_reshaped, original_batch_size, original_feature_dim, original_dim = reshape_metadata

        if was_reshaped:
            # Current shape: (T, original_feature_dim * original_batch_size, 1),
            # with the batch axis laid out feature-major.
            T = output.shape[0]
            unfolded = output.reshape(T, original_feature_dim, original_batch_size, -1)
            # -> (T, original_batch_size, original_feature_dim)
            output = unfolded.permute(0, 2, 1, 3).squeeze(-1)

        # A 2D context had a singleton batch axis inserted; remove it again.
        if original_dim == 2 and output.shape[1] == 1:
            output = output.squeeze(1)

        return output

    def forecast(self, context, horizon, preprocessing_method="pos_embedding",
                 standardize=True, fit_nonstationary=False, initial_x=None):
        """
        Efficient batched forecasting with the DynaMix model.

        This method implements a complete forecasting pipeline including:
        - Data preprocessing (Box-Cox, detrending, standardization)
        - Embedding techniques for dimensionality matching
        - DynaMix model prediction
        - Data postprocessing (inverse transformations)

        Args:
            context: Context data of shape (seq_length, batch_size, feature_dim)
                or (seq_length, feature_dim)
            horizon: Forecast horizon (number of steps to predict)
            preprocessing_method: Data preprocessing method ('pos_embedding',
                'zero_embedding', 'delay_embedding', or
                'delay_embedding_random') (default: 'pos_embedding')
            standardize: Whether to standardize the data (default: True)
            fit_nonstationary: Whether to fit a non-stationary time series;
                enables both Box-Cox and detrending in the preprocessor
                (default: False)
            initial_x: Optional initial condition of shape
                (batch_size, feature_dim) or (feature_dim,)

        Returns:
            Predicted sequence of shape (horizon, batch_size, feature_dim), or
            (horizon, feature_dim) when a 2D context was given

        Raises:
            ValueError: If ``context`` or ``initial_x`` have an unsupported
                shape (see ``_reshape_for_model``).
        """
        M = self.model.M
        N = self.model.N
        # A tensor context dictates the device; otherwise fall back to the model's.
        # NOTE(review): a tensor context is presumably expected to already live
        # on the model's device - confirm with callers.
        device = context.device if isinstance(context, torch.Tensor) else self.model.B.device
        model_dtype = next(self.model.parameters()).dtype

        # Convert inputs and fold surplus features into the batch axis if needed
        context, initial_x, shape_metadata = self._reshape_for_model(context, initial_x, device)
        batch_size, feature_dim = shape_metadata[0], shape_metadata[1]

        preprocessor = DataPreprocessor(
            standardize=standardize,
            box_cox=fit_nonstationary,
            detrending=fit_nonstationary,
            preprocessing_method=preprocessing_method
        )

        # Step 1: Apply preprocessing pipeline
        context_embedded, initial_condition = preprocessor.preprocess(context, N, initial_x)

        # Step 2: Initialize latent state
        z = self._init_latent_state(initial_condition)

        # Step 3: Perform forecasting loop (autocast is only enabled on CUDA)
        Z_gen = torch.empty(horizon, M, batch_size, device=device, dtype=model_dtype)
        on_cuda = device.type == 'cuda'
        with torch.amp.autocast(device_type='cuda' if on_cuda else 'cpu', enabled=on_cuda):
            precomputed_cnn = self.model.precompute_cnn(context_embedded)
            for t in range(horizon):
                z = self.model(z, context_embedded, precomputed_cnn=precomputed_cnn)
                Z_gen[t] = z

        # Step 4: Observations are the first feature_dim latent units
        output = Z_gen[:, :feature_dim, :].permute(0, 2, 1)  # (horizon, batch_size, feature_dim)

        # Step 5: Apply inverse data transformations (e.g. standardization, ...)
        output = preprocessor.postprocess(output)

        # Step 6: Reshape back to original dimensions if needed
        return self._reshape_to_original(output, shape_metadata)