Dmitry Beresnev
add wavelet analysis
0821f38
"""MODWT core — causal, shift-invariant wavelet decomposition.
Uses PyWavelets swt() (stationary wavelet transform = MODWT) with:
- Reflective left-padding up to the next power of two, which satisfies
  pywt.swt's requirement that the length be a multiple of 2**level
- Trimming of the padded region back off, plus an optional right-edge trim
  (trim_boundary) to discard boundary-contaminated coefficients
- norm=True to preserve additive variance decomposition
Reference: Percival & Walden (2000), Wavelet Methods for Time Series Analysis.
"""
from __future__ import annotations
import numpy as np
import pywt
def _next_power_of_two(n: int) -> int:
p = 1
while p < n:
p <<= 1
return p
def modwt_details_causal(
    series: np.ndarray,
    wavelet: str = "sym8",
    level: int = 6,
) -> dict[int, np.ndarray]:
    """MODWT decomposition returning detail coefficients D1..D_level.

    Strictly causal: the input must already contain only past data (the caller
    is responsible for windowing). Reflective padding is applied on the LEFT
    only, so right-edge coefficients are contaminated only by the boundary
    effect of the filter support — not by future prices.

    Args:
        series: 1-D array-like of values (typically log-prices),
            length >= 2**level.
        wavelet: PyWavelets wavelet name (default "sym8").
        level: decomposition depth (default 6); must be >= 1.

    Returns:
        dict mapping level index j (1..level) → detail coefficient array D_j,
        where D1 is the finest scale and D_level the coarsest. Each array has
        the same length as `series` (the left padding is trimmed off).

    Raises:
        ValueError: if `series` is not 1-D, `level` < 1, or the series is too
            short for the requested level.
    """
    # Coerce to an ndarray so slicing below is always positional
    # (e.g. for lists or pandas Series).
    values = np.asarray(series)
    if values.ndim != 1:
        raise ValueError(f"series must be 1-D, got shape {values.shape}")
    if level < 1:
        raise ValueError(f"level must be >= 1, got {level}")

    n = len(values)
    if n < 2 ** level:
        raise ValueError(
            f"Series length {n} is too short for level {level} MODWT "
            f"(need >= {2**level})"
        )

    # pywt.swt requires the length to be a multiple of 2**level. Because
    # n >= 2**level, the next power of two >= n is always such a multiple.
    required = _next_power_of_two(n)
    # required < 2n, so pad_left < n and the reflected slice is long enough.
    pad_left = required - n
    if pad_left > 0:
        # Reflective padding on the left — does not introduce future information
        padded = np.concatenate([values[:pad_left][::-1], values])
    else:
        padded = values

    # swt with norm=True preserves additive variance decomposition.
    # pywt.swt orders its output deepest level FIRST (like wavedec):
    #   [(cA_level, cD_level), ..., (cA2, cD2), (cA1, cD1)]
    coeffs = pywt.swt(padded, wavelet=wavelet, level=level, norm=True)

    # Walk the list finest-first so key j really holds detail level D_j, and
    # drop the left padding so each array lines up with the original input.
    details: dict[int, np.ndarray] = {}
    for j, (_approx, detail) in enumerate(reversed(coeffs), start=1):
        details[j] = detail[-n:]
    return details
def reconstruct_midband(
    details: dict[int, np.ndarray],
    levels: list[int],
) -> np.ndarray:
    """Add together the requested detail levels to form a mid-band component.

    Levels in `levels` that are absent from `details` are skipped; a level
    listed twice is added twice.

    Args:
        details: output of modwt_details_causal
        levels: e.g. [4, 5] for the 16–64 day band

    Returns:
        A new array holding the element-wise sum of the selected detail arrays.

    Raises:
        ValueError: if none of `levels` is a key of `details`.
    """
    selected = [details[j] for j in levels if j in details]
    if not selected:
        raise ValueError(f"None of levels {levels} found in details keys {list(details)}")
    # Start from the integer 0 so the result is always a fresh array,
    # never an alias of one of the inputs.
    total = 0
    for component in selected:
        total = total + component
    return total  # type: ignore[return-value]
def trim_boundary(arr: np.ndarray, max_level: int) -> np.ndarray:
    """Drop the final 2^(max_level-1) samples, treated as boundary-contaminated.

    NOTE(review): Percival & Walden give the MODWT boundary width at level j as
    (2**j - 1) * (L - 1) + 1 for a filter of length L, which is much wider than
    2**(max_level - 1) for sym8 — confirm this trim width is intended.

    Args:
        arr: detail or mid-band array
        max_level: highest signal level used (e.g. 5 when using D4+D5)

    Returns:
        `arr` without its last 2^(max_level-1) entries. If `arr` has no more
        than that many entries it is returned unchanged (the same object);
        callers must guard against that case themselves.
    """
    drop = 2 ** (max_level - 1)
    keep = len(arr) - drop
    if keep > 0:
        return arr[:keep]
    return arr