"""WaveletAnalyzer — main entry point for wavelet analysis.

Orchestrates:
  1. Data fetching (yfinance, cached)
  2. Full MODWT detail decomposition for visualization
  3. Walk-forward backtest
  4. Signal snapshot (current mid-band signal for the ticker)
  5. Returns a WaveletAnalysis dataclass ready for formatting

Usage:
    from src.transformations.wavelets.analyzer import WaveletAnalyzer
    result = await WaveletAnalyzer.analyze("SPY", "1d", equity=10_000)
"""
| from __future__ import annotations |
| import asyncio |
| import logging |
| from dataclasses import dataclass, field |
| from datetime import datetime, timezone |
| from typing import Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import yfinance as yf |
|
|
| from .backtest import run_backtest, LOOKBACK, WAVELET, LEVEL, SIG_LEVELS, SLOPE_WINDOW |
| from .signal import compute_signal, compute_midband_series, compute_all_details_series, volatility_target |
| from .stats import Stats, perf |
|
|
logger = logging.getLogger(__name__)


# History required before the walk-forward backtest is attempted:
# one full LOOKBACK window plus 20 extra bars.
_MIN_BARS_FOR_BACKTEST: int = LOOKBACK + 20
# History required to produce a signal snapshot at all; with fewer bars
# WaveletAnalyzer raises ValueError.
_MIN_BARS_FOR_SIGNAL: int = 256
|
|
|
|
@dataclass
class SignalSnapshot:
    """Point-in-time view of the MODWT mid-band signal for one ticker.

    All values describe the most recent bar of the fetched price history.
    """

    ticker: str                   # symbol the snapshot was computed for
    timeframe: str                # bar interval string passed to the data fetch
    current_price: float          # close of the most recent bar
    last_bar_time: pd.Timestamp   # index label of the most recent bar
    raw_signal: float             # unsized output of compute_signal
    sized_position: float         # raw_signal after volatility_target scaling
    realized_vol_ann: float       # annualised std of recent log returns
    midband_slope: float          # least-squares slope of the trimmed mid-band tail
    midband_last: float           # final value of the trimmed mid-band
    bars_used: int                # length of the log-price window fed to the signal
|
|
|
|
@dataclass
class WaveletAnalysis:
    """Result bundle returned by WaveletAnalyzer.analyze().

    ``ticker``, ``timeframe``, ``timestamp`` and ``signal`` are always set.
    Every other payload field defaults to ``None`` and is filled in only when
    the corresponding stage ran and succeeded. When the full decomposition or
    the backtest fails — or the backtest is skipped for lack of history — a
    human-readable message is appended to ``warnings``.
    """

    # --- identification ---------------------------------------------------
    ticker: str
    timeframe: str
    timestamp: datetime  # moment the analysis was produced (the analyzer uses UTC)

    # --- current signal ---------------------------------------------------
    signal: SignalSnapshot

    # --- backtest statistics (None unless the walk-forward backtest ran) ---
    strategy_stats: Optional[Stats] = None
    bh_stats: Optional[Stats] = None   # presumably buy-and-hold benchmark — confirm in backtest.py
    sma_stats: Optional[Stats] = None  # presumably SMA benchmark — confirm in backtest.py

    # --- full-history mid-band reconstruction, for plotting -----------------
    midband_series: Optional[pd.Series] = None

    # --- full-history MODWT details, one series per (presumably) level -----
    detail_series: Optional[dict[int, pd.Series]] = None

    # --- equity curves produced by the backtest -----------------------------
    equity_curve: Optional[pd.Series] = None
    bh_equity: Optional[pd.Series] = None
    sma_equity: Optional[pd.Series] = None

    # --- non-fatal problems encountered during the analysis -----------------
    warnings: list[str] = field(default_factory=list)
|
|
|
|
def _load_prices(ticker: str, timeframe: str, bars: int) -> pd.Series:
    """Download closing prices for *ticker* via yfinance.

    The history window is sized generously in calendar days so that, after
    non-trading days and NaN rows are dropped, roughly *bars* observations
    remain. Monthly and quarterly intervals request the full history
    (``"max"``): a day-count window of ``bars * 2`` days would contain only a
    small fraction of the requested number of bars.

    Args:
        ticker: Symbol understood by yfinance.
        timeframe: yfinance ``interval`` string (e.g. "1d", "1wk", "1mo").
        bars: Maximum number of most recent closes to return.

    Returns:
        Close series named after *ticker*, at most *bars* long.

    Raises:
        ValueError: if yfinance returns no rows, no ``Close`` column, or more
            than one ``Close`` column.
    """
    period_map = {
        # ~252 trading days per 365 calendar days; 2x leaves headroom.
        "1d": f"{max(bars * 2, 365)}d",
        # 7 calendar days per weekly bar, plus one day of headroom per bar.
        "1wk": f"{max(bars * 8, 365 * 2)}d",
        # Coarse intervals: take everything and trim with tail() below.
        "1mo": "max",
        "3mo": "max",
    }
    period = period_map.get(timeframe, f"{bars * 2}d")

    df = yf.download(
        ticker,
        period=period,
        interval=timeframe,
        auto_adjust=True,
        progress=False,
        threads=False,
    )
    if df.empty:
        raise ValueError(f"No data returned for {ticker} ({timeframe})")

    # Newer yfinance versions return (field, ticker) MultiIndex columns even
    # for a single symbol; keep only the field level.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    if "Close" not in df.columns:
        raise ValueError(f"No Close column returned for {ticker} ({timeframe})")

    close = df["Close"]
    if isinstance(close, pd.DataFrame):
        # Duplicate "Close" labels after flattening presumably mean the ticker
        # string resolved to more than one symbol.
        raise ValueError(f"Expected a single ticker, got {ticker!r}")

    close = close.dropna()
    close.name = ticker
    return close.tail(bars)
|
|
|
|
| def _compute_midband_slope(midband_safe: np.ndarray, slope_window: int) -> float: |
| if len(midband_safe) < slope_window: |
| return 0.0 |
| window = midband_safe[-slope_window:] |
| n = len(window) |
| x = np.arange(n, dtype=float) - (n - 1) / 2 |
| denom = (x * x).sum() |
| return float((x * (window - window.mean())).sum() / denom) if denom != 0 else 0.0 |
|
|
|
|
class WaveletAnalyzer:
    """Stateless wavelet analysis for a ticker + timeframe.

    Every method is a ``staticmethod``; the class is purely a namespace.
    ``analyze`` is the public (async) entry point and delegates all blocking
    work to ``_analyze_sync``.
    """

    # Bars per year used to annualise realised volatility, keyed by the
    # yfinance interval string. Unlisted (e.g. intraday) intervals fall back
    # to the daily figure.
    # NOTE(review): backtest.py may still annualise with 252 for every
    # timeframe — confirm the two stay consistent.
    _PERIODS_PER_YEAR = {"1d": 252, "1wk": 52, "1mo": 12, "3mo": 4}

    @staticmethod
    async def analyze(
        ticker: str,
        timeframe: str = "1d",
        equity: float = 10_000.0,
        run_full_backtest: bool = True,
        wavelet: str = WAVELET,
        level: int = LEVEL,
        sig_levels: list[int] | None = None,
    ) -> WaveletAnalysis:
        """Full analysis: signal snapshot + optional walk-forward backtest.

        The blocking work (network fetch and numerical decomposition) runs in
        the default executor so the async event loop is not blocked.

        Args:
            ticker: Symbol to analyse.
            timeframe: yfinance interval string, e.g. "1d" or "1wk".
            equity: Accepted for API compatibility.
                NOTE(review): not currently used by the analysis — confirm
                whether run_backtest should receive it.
            run_full_backtest: Whether to run the walk-forward backtest.
            wavelet: Wavelet name forwarded to the decomposition helpers.
            level: Decomposition depth forwarded to the helpers.
            sig_levels: Detail levels that form the mid-band; defaults to
                ``SIG_LEVELS``.

        Returns:
            A populated ``WaveletAnalysis``.

        Raises:
            ValueError: if prices cannot be fetched or too few bars exist.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: WaveletAnalyzer._analyze_sync(
                ticker, timeframe, equity, run_full_backtest, wavelet, level, sig_levels
            ),
        )

    @staticmethod
    def _analyze_sync(
        ticker: str,
        timeframe: str,
        equity: float,
        run_full_backtest: bool,
        wavelet: str,
        level: int,
        sig_levels: list[int] | None,
    ) -> WaveletAnalysis:
        """Blocking implementation of ``analyze``; see that method for arguments.

        Raises:
            ValueError: if prices cannot be fetched or fewer than
                ``_MIN_BARS_FOR_SIGNAL`` bars are available.
        """
        if sig_levels is None:
            sig_levels = SIG_LEVELS

        warnings: list[str] = []

        prices = WaveletAnalyzer._fetch_prices(ticker, timeframe, run_full_backtest)
        log_prices = np.log(prices.values)

        snapshot = WaveletAnalyzer._build_snapshot(
            ticker, timeframe, prices, log_prices, wavelet, level, sig_levels, warnings
        )
        midband_vis, detail_vis = WaveletAnalyzer._decompose_for_display(
            ticker, prices, log_prices, wavelet, level, sig_levels, warnings
        )
        bt = (
            WaveletAnalyzer._run_backtest_best_effort(
                ticker, prices, wavelet, level, sig_levels, warnings
            )
            if run_full_backtest
            else {}
        )

        return WaveletAnalysis(
            ticker=ticker,
            timeframe=timeframe,
            timestamp=datetime.now(timezone.utc),
            signal=snapshot,
            strategy_stats=bt.get("strategy_stats"),
            bh_stats=bt.get("bh_stats"),
            sma_stats=bt.get("sma_stats"),
            midband_series=midband_vis,
            detail_series=detail_vis,
            equity_curve=bt.get("equity_curve"),
            bh_equity=bt.get("bh_equity"),
            sma_equity=bt.get("sma_equity"),
            warnings=warnings,
        )

    @staticmethod
    def _fetch_prices(ticker: str, timeframe: str, run_full_backtest: bool) -> pd.Series:
        """Download enough closes for the requested work and enforce the signal minimum."""
        # Always request at least the signal minimum. If the backtest minimum
        # (LOOKBACK + 20) is smaller, fetching only that many bars would make
        # the "too few bars" check below fail unconditionally.
        needed = _MIN_BARS_FOR_SIGNAL
        if run_full_backtest:
            needed = max(needed, _MIN_BARS_FOR_BACKTEST)
        try:
            prices = _load_prices(ticker, timeframe, needed + 50)
        except Exception as e:
            raise ValueError(f"Data fetch failed for {ticker}: {e}") from e

        if len(prices) < _MIN_BARS_FOR_SIGNAL:
            raise ValueError(
                f"Only {len(prices)} bars available for {ticker} — need {_MIN_BARS_FOR_SIGNAL}+"
            )
        return prices

    @staticmethod
    def _build_snapshot(
        ticker: str,
        timeframe: str,
        prices: pd.Series,
        log_prices: np.ndarray,
        wavelet: str,
        level: int,
        sig_levels: list[int],
        warnings: list[str],
    ) -> SignalSnapshot:
        """Compute the current raw/sized signal plus mid-band diagnostics."""
        # Signal uses the most recent LOOKBACK log prices (or all, if fewer).
        window_size = min(LOOKBACK, len(prices))
        window = log_prices[-window_size:]

        raw_sig = compute_signal(
            window,
            wavelet=wavelet,
            level=level,
            sig_levels=sig_levels,
            slope_window=SLOPE_WINDOW,
        )

        # Bar-to-bar log returns; the first entry is 0 because of prepend.
        bar_ret = np.diff(log_prices, prepend=log_prices[0])
        vol_window = min(60, len(bar_ret) - 1)
        periods_per_year = WaveletAnalyzer._PERIODS_PER_YEAR.get(timeframe, 252)
        realized_vol = float(np.std(bar_ret[-vol_window:]) * np.sqrt(periods_per_year))

        sized_pos = volatility_target(raw_sig, realized_vol)

        from .modwt import modwt_details_causal, reconstruct_midband, trim_boundary
        try:
            details = modwt_details_causal(window, wavelet=wavelet, level=level)
            midband = reconstruct_midband(details, sig_levels)
            safe_midband = trim_boundary(midband, max_level=max(sig_levels))
        except Exception as e:
            # Best effort: the raw signal above is still valid, so degrade the
            # mid-band diagnostics to neutral values but record why.
            logger.warning("Mid-band snapshot failed for %s", ticker, exc_info=True)
            warnings.append(f"Mid-band snapshot failed: {e}")
            safe_midband = np.zeros(1)

        return SignalSnapshot(
            ticker=ticker,
            timeframe=timeframe,
            current_price=float(prices.iloc[-1]),
            last_bar_time=prices.index[-1],
            raw_signal=raw_sig,
            sized_position=sized_pos,
            realized_vol_ann=realized_vol,
            midband_slope=_compute_midband_slope(safe_midband, SLOPE_WINDOW),
            midband_last=float(safe_midband[-1]) if len(safe_midband) > 0 else 0.0,
            bars_used=window_size,
        )

    @staticmethod
    def _decompose_for_display(
        ticker: str,
        prices: pd.Series,
        log_prices: np.ndarray,
        wavelet: str,
        level: int,
        sig_levels: list[int],
        warnings: list[str],
    ) -> tuple:
        """Best-effort full-history mid-band and per-level details for plotting."""
        log_series = pd.Series(log_prices, index=prices.index)
        midband_vis = detail_vis = None
        try:
            midband_vis = compute_midband_series(
                log_series, wavelet=wavelet, level=level, sig_levels=sig_levels,
            )
            detail_vis = compute_all_details_series(
                log_series, wavelet=wavelet, level=level,
            )
        except Exception as e:
            # midband_vis is kept if only the details step failed.
            logger.warning("Full decomposition failed for %s", ticker, exc_info=True)
            warnings.append(f"Full decomposition failed: {e}")
        return midband_vis, detail_vis

    @staticmethod
    def _run_backtest_best_effort(
        ticker: str,
        prices: pd.Series,
        wavelet: str,
        level: int,
        sig_levels: list[int],
        warnings: list[str],
    ) -> dict:
        """Run the walk-forward backtest, recording problems in *warnings* instead of raising."""
        if len(prices) < _MIN_BARS_FOR_BACKTEST:
            warnings.append(
                f"Only {len(prices)} bars — backtest requires {_MIN_BARS_FOR_BACKTEST}+. "
                "Showing signal snapshot only."
            )
            return {}

        results: dict = {}
        try:
            bt = run_backtest(
                prices,
                wavelet=wavelet,
                level=level,
                sig_levels=sig_levels,
            )
            # Copy key by key so that, if a key is missing, the entries already
            # copied are kept and the failure is reported as a warning.
            for key in (
                "strategy_stats", "bh_stats", "sma_stats",
                "equity_curve", "bh_equity", "sma_equity",
            ):
                results[key] = bt[key]
        except Exception as e:
            logger.warning("Backtest failed for %s", ticker, exc_info=True)
            warnings.append(f"Backtest failed: {e}")
        return results
|
|