"""WaveletAnalyzer — main entry point for wavelet analysis. Orchestrates: 1. Data fetching (yfinance, cached) 2. Full MODWT detail decomposition for visualization 3. Walk-forward backtest 4. Signal snapshot (current mid-band signal for the ticker) 5. Returns an Analysis dataclass ready for formatting Usage: from src.transformations.wavelets.analyzer import WaveletAnalyzer result = await WaveletAnalyzer.analyze("SPY", "1d", equity=10_000) """ from __future__ import annotations import asyncio import logging from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Optional import numpy as np import pandas as pd import yfinance as yf from .backtest import run_backtest, LOOKBACK, WAVELET, LEVEL, SIG_LEVELS, SLOPE_WINDOW from .signal import compute_signal, compute_midband_series, compute_all_details_series, volatility_target from .stats import Stats, perf logger = logging.getLogger(__name__) _MIN_BARS_FOR_BACKTEST = LOOKBACK + 20 _MIN_BARS_FOR_SIGNAL = 256 @dataclass class SignalSnapshot: """Current state of the MODWT signal for a single ticker.""" ticker: str timeframe: str current_price: float last_bar_time: pd.Timestamp raw_signal: float # +1.0, -1.0, or 0.0 sized_position: float # after vol targeting realized_vol_ann: float # 60-day annualized realized vol midband_slope: float # the underlying slope value (not just sign) midband_last: float # mid-band value at the last safe point bars_used: int @dataclass class WaveletAnalysis: """Full output of WaveletAnalyzer.analyze().""" ticker: str timeframe: str timestamp: datetime signal: SignalSnapshot # Backtest stats (None if price history too short) strategy_stats: Optional[Stats] = None bh_stats: Optional[Stats] = None sma_stats: Optional[Stats] = None # Mid-band reconstruction series (for visualization, full history) midband_series: Optional[pd.Series] = None # Per-level detail series (for visualization) detail_series: Optional[dict[int, pd.Series]] = None # Backtest equity curves equity_curve: Optional[pd.Series] = None bh_equity: Optional[pd.Series] = None sma_equity: Optional[pd.Series] = None # Warnings warnings: list[str] = field(default_factory=list) def _load_prices(ticker: str, timeframe: str, bars: int) -> pd.Series: """Fetch OHLCV via yfinance and return Close series.""" period_map = { "1d": f"{max(bars * 2, 365)}d", "1wk": f"{max(bars * 8, 365 * 2)}d", } period = period_map.get(timeframe, f"{bars * 2}d") df = yf.download( ticker, period=period, interval=timeframe, auto_adjust=True, progress=False, threads=False, ) if df.empty: raise ValueError(f"No data returned for {ticker} ({timeframe})") if isinstance(df.columns, pd.MultiIndex): df.columns = df.columns.get_level_values(0) close = df["Close"].dropna() close.name = ticker return close.tail(bars) def _compute_midband_slope(midband_safe: np.ndarray, slope_window: int) -> float: if len(midband_safe) < slope_window: return 0.0 window = midband_safe[-slope_window:] n = len(window) x = np.arange(n, dtype=float) - (n - 1) / 2 denom = (x * x).sum() return float((x * (window - window.mean())).sum() / denom) if denom != 0 else 0.0 class WaveletAnalyzer: """Stateless wavelet analysis for a ticker + timeframe.""" @staticmethod async def analyze( ticker: str, timeframe: str = "1d", equity: float = 10_000.0, run_full_backtest: bool = True, wavelet: str = WAVELET, level: int = LEVEL, sig_levels: list[int] | None = None, ) -> WaveletAnalysis: """Full analysis: signal snapshot + optional walk-forward backtest. Runs in executor to avoid blocking the async event loop. """ loop = asyncio.get_event_loop() return await loop.run_in_executor( None, lambda: WaveletAnalyzer._analyze_sync( ticker, timeframe, equity, run_full_backtest, wavelet, level, sig_levels ), ) @staticmethod def _analyze_sync( ticker: str, timeframe: str, equity: float, run_full_backtest: bool, wavelet: str, level: int, sig_levels: list[int] | None, ) -> WaveletAnalysis: if sig_levels is None: sig_levels = SIG_LEVELS warnings: list[str] = [] max_sig_level = max(sig_levels) trim = 2 ** (max_sig_level - 1) # ── 1. Fetch data ───────────────────────────────────────────────────── bars_to_fetch = _MIN_BARS_FOR_BACKTEST + 50 if run_full_backtest else _MIN_BARS_FOR_SIGNAL + 50 try: prices = _load_prices(ticker, timeframe, bars_to_fetch) except Exception as e: raise ValueError(f"Data fetch failed for {ticker}: {e}") from e if len(prices) < _MIN_BARS_FOR_SIGNAL: raise ValueError( f"Only {len(prices)} bars available for {ticker} — need {_MIN_BARS_FOR_SIGNAL}+" ) log_prices = np.log(prices.values) daily_ret = np.diff(log_prices, prepend=log_prices[0]) # ── 2. Current signal snapshot ──────────────────────────────────────── signal_window_size = min(LOOKBACK, len(prices)) signal_window = log_prices[-signal_window_size:] raw_sig = compute_signal( signal_window, wavelet=wavelet, level=level, sig_levels=sig_levels, slope_window=SLOPE_WINDOW, ) # Realized vol from past 60 bars vol_window = min(60, len(daily_ret) - 1) realized_vol = float(np.std(daily_ret[-vol_window:]) * np.sqrt(252)) sized_pos = volatility_target(raw_sig, realized_vol) # Mid-band values in the signal window for slope + last value from .modwt import modwt_details_causal, reconstruct_midband, trim_boundary try: details = modwt_details_causal(signal_window, wavelet=wavelet, level=level) midband = reconstruct_midband(details, sig_levels) safe_midband = trim_boundary(midband, max_level=max_sig_level) except Exception: safe_midband = np.zeros(1) slope = _compute_midband_slope(safe_midband, SLOPE_WINDOW) midband_last = float(safe_midband[-1]) if len(safe_midband) > 0 else 0.0 snapshot = SignalSnapshot( ticker=ticker, timeframe=timeframe, current_price=float(prices.iloc[-1]), last_bar_time=prices.index[-1], raw_signal=raw_sig, sized_position=sized_pos, realized_vol_ann=realized_vol, midband_slope=slope, midband_last=midband_last, bars_used=signal_window_size, ) # ── 3. Full decomposition for visualization ─────────────────────────── midband_vis = None detail_vis = None try: midband_vis = compute_midband_series( pd.Series(log_prices, index=prices.index), wavelet=wavelet, level=level, sig_levels=sig_levels, ) detail_vis = compute_all_details_series( pd.Series(log_prices, index=prices.index), wavelet=wavelet, level=level, ) except Exception as e: warnings.append(f"Full decomposition failed: {e}") # ── 4. Walk-forward backtest ────────────────────────────────────────── strat_stats = bh_stats = sma_stats = None equity_curve = bh_equity = sma_equity = None if run_full_backtest: if len(prices) < _MIN_BARS_FOR_BACKTEST: warnings.append( f"Only {len(prices)} bars — backtest requires {_MIN_BARS_FOR_BACKTEST}+. " "Showing signal snapshot only." ) else: try: bt = run_backtest( prices, wavelet=wavelet, level=level, sig_levels=sig_levels, ) strat_stats = bt["strategy_stats"] bh_stats = bt["bh_stats"] sma_stats = bt["sma_stats"] equity_curve = bt["equity_curve"] bh_equity = bt["bh_equity"] sma_equity = bt["sma_equity"] except Exception as e: warnings.append(f"Backtest failed: {e}") return WaveletAnalysis( ticker=ticker, timeframe=timeframe, timestamp=datetime.now(timezone.utc), signal=snapshot, strategy_stats=strat_stats, bh_stats=bh_stats, sma_stats=sma_stats, midband_series=midband_vis, detail_series=detail_vis, equity_curve=equity_curve, bh_equity=bh_equity, sma_equity=sma_equity, warnings=warnings, )