"""MODWT core — causal, shift-invariant wavelet decomposition.

Uses PyWavelets swt() (stationary wavelet transform = MODWT) with:
  - Reflective left-padding up to the next power-of-two length
    (pywt.swt requires a length that is a multiple of 2**level)
  - Boundary trim on the right to discard contaminated coefficients
  - norm=True to preserve additive variance decomposition

Reference: Percival & Walden (2000), Wavelet Methods for Time Series Analysis.
"""

from __future__ import annotations

import numpy as np
import pywt


def _next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is >= ``n`` (1 when n <= 1)."""
    p = 1
    while p < n:
        p <<= 1
    return p


def modwt_details_causal(
    series: np.ndarray,
    wavelet: str = "sym8",
    level: int = 6,
) -> dict[int, np.ndarray]:
    """MODWT decomposition returning detail coefficients D1..D_level.

    The function only ever sees the samples it is given, so causality is the
    caller's responsibility: pass a window that contains past data only.
    Padding is applied on the LEFT only (a mirrored copy of the earliest
    samples), so no values beyond the end of ``series`` are invented.

    Args:
        series: 1-D array of values (typically log-prices),
            length >= 2**level.
        wavelet: PyWavelets wavelet name (default "sym8").
        level: decomposition depth, >= 1 (default 6).

    Returns:
        dict mapping level index (1..level) → detail coefficient array.
        Key 1 is the finest scale (D1) and key ``level`` the coarsest.
        Each array has the same length as ``series`` (the left padding is
        trimmed from the transform output).

    Raises:
        ValueError: if ``level`` is < 1, ``series`` is not 1-D, or the
            series is too short for the requested level.
    """
    if level < 1:
        raise ValueError(f"level must be >= 1, got {level}")

    values = np.asarray(series)
    if values.ndim != 1:
        raise ValueError(
            f"series must be 1-D, got array with shape {values.shape}"
        )

    n = len(values)
    if n < 2 ** level:
        raise ValueError(
            f"Series length {n} is too short for level {level} MODWT "
            f"(need >= {2**level})"
        )

    # pywt.swt requires the length to be a multiple of 2**level. Because
    # n >= 2**level, the next power of two >= n always satisfies that.
    required = _next_power_of_two(n)
    pad_left = required - n  # always < n, so the slice below never runs short

    if pad_left > 0:
        # Mirror the earliest samples onto the left edge; nothing is added
        # on the right, so no post-window values are fabricated.
        padded = np.concatenate([values[:pad_left][::-1], values])
    else:
        padded = values.copy()

    # norm=True rescales the filters so the transform preserves the additive
    # variance decomposition.
    coeffs = pywt.swt(padded, wavelet=wavelet, level=level, norm=True)

    # pywt.swt returns [(cA_level, cD_level), ..., (cA1, cD1)] — coarsest
    # level FIRST. Reverse before numbering so that key j really is D_j.
    # Keep only the last n samples, i.e. drop the left padding.
    return {
        j: detail[-n:]
        for j, (_approx, detail) in enumerate(reversed(coeffs), start=1)
    }


def reconstruct_midband(
    details: dict[int, np.ndarray],
    levels: list[int],
) -> np.ndarray:
    """Sum the specified detail levels to produce a mid-band component.

    Levels that are requested but missing from ``details`` are skipped
    silently; only the case where none of them is present is an error.

    Args:
        details: output of modwt_details_causal.
        levels: e.g. [4, 5] for the 16–64 day band.

    Returns:
        Sum of the requested detail arrays that are present (same length as
        each detail).

    Raises:
        ValueError: if none of ``levels`` is a key of ``details``.
    """
    present = [j for j in levels if j in details]
    if not present:
        raise ValueError(f"None of levels {levels} found in details keys {list(details)}")
    return sum(details[j] for j in present)  # type: ignore[return-value]


def trim_boundary(arr: np.ndarray, max_level: int) -> np.ndarray:
    """Discard the rightmost 2^(max_level-1) samples — boundary-contaminated region.

    Args:
        arr: detail or mid-band array.
        max_level: highest signal level used (e.g. 5 when using D4+D5).

    Returns:
        Trimmed array (shorter by 2^(max_level-1)). If ``arr`` has no more
        than that many samples it is returned unchanged, so the caller must
        guard against inputs that are too short to trim.
    """
    trim = 2 ** (max_level - 1)
    if len(arr) <= trim:
        return arr  # too short to trim; caller must guard
    return arr[:-trim]