File size: 3,616 Bytes
0821f38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""MODWT core — causal, shift-invariant wavelet decomposition.

Uses PyWavelets swt() (stationary wavelet transform = MODWT) with:
  - Reflective left-padding to satisfy power-of-2 length requirement
  - Boundary trim on the right to discard contaminated coefficients
  - norm=True to preserve additive variance decomposition

Reference: Percival & Walden (2000), Wavelet Methods for Time Series Analysis.
"""
from __future__ import annotations
import numpy as np
import pywt


def _next_power_of_two(n: int) -> int:
    p = 1
    while p < n:
        p <<= 1
    return p


def modwt_details_causal(
    series: np.ndarray,
    wavelet: str = "sym8",
    level: int = 6,
) -> dict[int, np.ndarray]:
    """MODWT (stationary wavelet transform) of a window, returning D1..D_level.

    The input must already contain only past data (the caller is responsible
    for windowing). Reflective padding is applied on the LEFT only, so the
    right-edge coefficients are affected by the boundary of the filter support
    (the transform wraps into the padded region) — not by future prices,
    since none are present in ``series``.

    NOTE(review): the SWT filters used by ``pywt.swt`` are not one-sided, so a
    coefficient at index t may depend on samples after t *within the window*;
    causality holds only relative to the end of the window. Drop the
    right-edge region (e.g. with ``trim_boundary``) before relying on the most
    recent coefficients.

    Args:
        series:  1-D array of values (typically log-prices), length >= 2**level.
        wavelet: PyWavelets wavelet name (default "sym8").
        level:   decomposition depth, >= 1 (default 6).

    Returns:
        dict mapping level index j to that level's detail coefficient array,
        with j = 1 the finest scale and j = ``level`` the coarsest. Each array
        has the same length as ``series`` (the left padding is trimmed off).

    Raises:
        ValueError: if ``series`` is not 1-D, ``level`` < 1, or ``series`` is
            too short for the requested level.
    """
    series = np.asarray(series)
    if series.ndim != 1:
        raise ValueError(f"series must be 1-D, got shape {series.shape}")
    if level < 1:
        raise ValueError(f"level must be >= 1, got {level}")

    n = len(series)
    if n < 2 ** level:
        raise ValueError(
            f"Series length {n} is too short for level {level} MODWT "
            f"(need >= {2**level})"
        )

    # pywt.swt requires the length to be a multiple of 2**level; any power of
    # two that is >= 2**level satisfies this (guaranteed since n >= 2**level).
    required = max(_next_power_of_two(n), 2 ** level)

    pad_left = required - n
    if pad_left > 0:
        # Reflective padding on the left only — uses the window's own earliest
        # samples, so no future information is introduced. pad_left < n here
        # (the next power of two above n is < 2n), so the slice is long enough.
        padded = np.concatenate([series[:pad_left][::-1], series])
    else:
        padded = series.copy()

    # norm=True applies PyWavelets' normalized (MODWT-style) scaling.
    # NOTE(review): per the PyWavelets docs, exact energy equality between
    # data and coefficients also requires trim_approx=True; only the detail
    # arrays are used here.
    coeffs = pywt.swt(padded, wavelet=wavelet, level=level, norm=True)

    # pywt.swt returns [(cA_level, cD_level), ..., (cA1, cD1)] — COARSEST
    # level first — so iterate in reverse to make key 1 the finest scale.
    # Keep the last n samples to drop the left padding.
    return {
        j: cD[-n:]
        for j, (_cA, cD) in enumerate(reversed(coeffs), start=1)
    }


def reconstruct_midband(
    details: dict[int, np.ndarray],
    levels: list[int],
) -> np.ndarray:
    """Sum the requested detail levels into a single mid-band component.

    Args:
        details:  mapping of level index -> detail array, e.g. the output of
                  modwt_details_causal.
        levels:   levels to combine, e.g. [4, 5] for the 16–64 day band.
                  Any iterable of ints is accepted. A level listed more than
                  once is counted once; levels missing from ``details`` are
                  skipped.

    Returns:
        A new array holding the element-wise sum of the selected detail
        arrays (same length as each detail).

    Raises:
        ValueError: if none of ``levels`` is present in ``details``.

    NOTE(review): modwt_details_causal returns the raw cD arrays from
    pywt.swt (wavelet coefficients), not inverse-transformed MRA components;
    summing coefficients is an aggregate, not an exact additive band
    reconstruction — confirm this is the intended quantity.
    """
    # Materialize once: ``levels`` is read again for the error message.
    levels = list(levels)
    # dict.fromkeys de-duplicates while preserving the caller's order, so a
    # repeated level does not double-count its band.
    present = [j for j in dict.fromkeys(levels) if j in details]
    if not present:
        raise ValueError(f"None of levels {levels} found in details keys {list(details)}")
    return sum(details[j] for j in present)  # type: ignore[return-value]


def trim_boundary(arr: np.ndarray, max_level: int) -> np.ndarray:
    """Discard the rightmost 2**(max_level - 1) samples (boundary-contaminated region).

    Args:
        arr:        detail or mid-band array.
        max_level:  highest signal level used (e.g. 5 when using D4+D5);
                    must be >= 1.

    Returns:
        ``arr`` without its last 2**(max_level - 1) samples. If ``arr`` is not
        longer than that, it is returned unchanged — the caller must guard
        against this case.

    Raises:
        ValueError: if ``max_level`` < 1 (the trim length would not be a
            whole number of samples).

    NOTE(review): 2**(max_level - 1) samples is far shorter than the support
    of a level-``max_level`` filter for long wavelets such as sym8; confirm
    the trim length is sufficient for the wavelet in use.
    """
    if max_level < 1:
        # 2 ** (max_level - 1) would be fractional and break slicing.
        raise ValueError(f"max_level must be >= 1, got {max_level}")
    trim = 2 ** (max_level - 1)
    if len(arr) <= trim:
        return arr  # too short to trim; caller must guard
    return arr[:-trim]