Spaces:
Running
Running
| """ | |
| Hidden Markov Model Regime Detector. | |
| Fits a Gaussian HMM with 3 hidden states (bull, bear, high-volatility) | |
| on daily returns and realized volatility. Classifies the current market | |
| regime and outputs state probabilities, transition matrix, and history. | |
| Design decisions: | |
| - In-memory model cache with 6-hour TTL β auto-retrains on fresh data | |
| - No disk persistence β HuggingFace / serverless friendly | |
| - 3 regimes mapped by sorting on mean return (highest=bull, lowest=bear, | |
| middle=high-volatility if vol is highest, else sideways) | |
| - Uses log returns + rolling vol as observation features | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| logger = logging.getLogger(__name__) | |
| # ββ Cache ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CACHE_TTL_SECONDS = 6 * 3600 # 6 hours | |
| class CachedHMM: | |
| """Cached HMM model with metadata.""" | |
| model: Any | |
| label_map: Dict[int, str] | |
| trained_at: float = field(default_factory=time.time) | |
| def is_stale(self) -> bool: | |
| return (time.time() - self.trained_at) > CACHE_TTL_SECONDS | |
| _hmm_cache: Dict[str, CachedHMM] = {} | |
| # ββ Regime Labels ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# UI metadata per regime: emoji, hex colour, and a human-readable blurb.
# NOTE(review): the emoji/description values below contain mis-encoded
# (mojibake) characters inherited from the source file. They are runtime
# strings consumed by callers, so they are left byte-for-byte untouched
# here -- repair the encoding deliberately, with callers verified.
REGIME_LABELS = {
    "bull": {"emoji": "π", "color": "#22c55e", "description": "Uptrend β positive returns, moderate volatility"},
    "bear": {"emoji": "π»", "color": "#ef4444", "description": "Downtrend β negative returns, rising volatility"},
    "high_volatility": {"emoji": "β‘", "color": "#f59e0b", "description": "Choppy β mixed returns, elevated volatility"},
}
| # ββ Feature Construction βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _build_observations(df: pd.DataFrame) -> Tuple[np.ndarray, pd.DatetimeIndex]: | |
| """ | |
| Build observation matrix for HMM from OHLCV data. | |
| Features: | |
| 1. Log return (daily) | |
| 2. 5-day rolling volatility (annualised) | |
| 3. 10-day rolling volatility (annualised) | |
| """ | |
| close = df["Close"] | |
| log_ret = np.log(close / close.shift(1)) | |
| vol_5 = log_ret.rolling(5).std() * np.sqrt(252) | |
| vol_10 = log_ret.rolling(10).std() * np.sqrt(252) | |
| obs_df = pd.DataFrame({ | |
| "log_return": log_ret, | |
| "vol_5d": vol_5, | |
| "vol_10d": vol_10, | |
| }).dropna() | |
| return obs_df.values, obs_df.index | |
def _map_states_to_regimes(model: Any, n_states: int = 3) -> Dict[int, str]:
    """
    Map HMM hidden states to named regimes by analysing state means.

    Logic (states sorted by mean log return, the first observation feature):
      - Lowest mean return  -> "bear"
      - Highest mean return -> "bull"
      - Every state in between -> "high_volatility"

    The original implementation hard-coded exactly three sorted indices,
    which raised IndexError for n_states < 3 and silently mislabelled
    states for n_states > 3. This version handles any n_states >= 2 and
    produces identical output for the default three-state case.

    Args:
        model: fitted GaussianHMM exposing ``means_`` (n_states x features).
        n_states: number of hidden states (used for logging).

    Returns:
        Dict mapping hidden-state index -> regime label.

    Raises:
        ValueError: if the model has fewer than 2 states.
    """
    means = model.means_[:, 0]  # Mean log return per state
    if len(means) < 2:
        raise ValueError(f"Need at least 2 hidden states, got {len(means)}")
    sorted_indices = np.argsort(means)
    # Middle states first, then endpoints, so bear/bull always win ties.
    label_map = {int(s): "high_volatility" for s in sorted_indices[1:-1]}
    label_map[int(sorted_indices[0])] = "bear"
    label_map[int(sorted_indices[-1])] = "bull"
    logger.info(
        "HMM regime mapping: %s (means: %s)",
        label_map,
        {int(i): round(float(means[i]), 6) for i in range(n_states)},
    )
    return label_map
| # ββ Training βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _train_hmm(
    df: pd.DataFrame,
    n_states: int = 3,
    n_iter: int = 100,
) -> Tuple[Any, Dict[int, str]]:
    """
    Fit a GaussianHMM on historical price data.

    Args:
        df: OHLCV DataFrame with a "Close" column.
        n_states: number of hidden regimes to fit.
        n_iter: maximum EM iterations.

    Returns:
        Tuple of (fitted model, hidden-state -> regime-label map).

    Raises:
        ValueError: if fewer than 60 usable observations are available.
    """
    from hmmlearn.hmm import GaussianHMM

    obs, _dates = _build_observations(df)
    n_obs = len(obs)
    if n_obs < 60:
        raise ValueError(f"Insufficient data: {n_obs} obs (need 60+)")

    hmm = GaussianHMM(
        n_components=n_states,
        covariance_type="full",
        n_iter=n_iter,
        random_state=42,
        tol=0.01,
    )
    hmm.fit(obs)
    regime_map = _map_states_to_regimes(hmm, n_states)
    logger.info(
        "HMM trained: %d observations, %d states, converged=%s",
        n_obs, n_states, hmm.monitor_.converged,
    )
    return hmm, regime_map
| # ββ Prediction βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def detect_regime(
    ticker: str,
    period: str = "2y",
    n_states: int = 3,
    history_days: int = 60,
) -> Dict[str, Any]:
    """
    Detect the current market regime for a ticker.

    Uses the cached model if it is younger than CACHE_TTL_SECONDS,
    otherwise fetches fresh data and retrains.

    Args:
        ticker: symbol to analyse.
        period: lookback window passed to the data adapter (e.g. "2y").
        n_states: number of hidden HMM states.
        history_days: trailing days included in regime_history.

    Returns dict with:
        - current_regime: "bull", "bear", or "high_volatility"
        - regime_info: emoji/colour/description metadata
        - regime_probabilities: latest-day probability per regime
        - transition_matrix: NxN transition probabilities, keyed by name
        - regime_history: last N days with labels + probabilities
        - regime_stats: per-regime return/volatility statistics
        - from_cache: whether a cached model was reused
        - total_observations: number of rows scored

    Raises:
        ValueError: if insufficient price or observation data is available.
    """
    from app.services.data_ingestion.yahoo import yahoo_adapter

    cache_key = f"{ticker}_{period}_{n_states}"
    cached = _hmm_cache.get(cache_key)
    used_cache = False
    df_latest = None

    # BUG FIX: the original tested `not cached.is_stale` -- referencing the
    # bound method (always truthy) instead of calling it -- so this branch
    # could never be taken and every request retrained the model.
    if cached and not cached.is_stale():
        model = cached.model
        label_map = cached.label_map
        used_cache = True
        logger.info("Using cached HMM for %s (age: %.0fs)", ticker,
                    time.time() - cached.trained_at)
    else:
        # Fetch fresh data and train; the same frame is reused for scoring
        # below (the original fetched the identical data a second time).
        df_latest = await yahoo_adapter.get_price_dataframe(ticker, period=period)
        if df_latest.empty or len(df_latest) < 60:
            raise ValueError(f"Insufficient price data for {ticker}")
        model, label_map = _train_hmm(df_latest, n_states=n_states)
        _hmm_cache[cache_key] = CachedHMM(
            model=model,
            label_map=label_map,
        )

    if df_latest is None:
        # Cache hit: fetch current data for scoring, which may be newer
        # than the data the cached model was trained on.
        df_latest = await yahoo_adapter.get_price_dataframe(ticker, period=period)

    observations, dates = _build_observations(df_latest)
    if len(observations) < 10:
        raise ValueError(f"Insufficient observation data for {ticker}")

    # Most likely state sequence (Viterbi) and per-day state posteriors.
    hidden_states = model.predict(observations)
    state_probs = model.predict_proba(observations)
    current_probs = state_probs[-1]

    # Current regime; "unknown" guards against an unmapped state index.
    current_state = int(hidden_states[-1])
    current_regime = label_map.get(current_state, "unknown")

    # Latest-day probability per named regime.
    regime_probabilities = {
        label: round(float(current_probs[state_idx]), 4)
        for state_idx, label in label_map.items()
    }

    # Transition matrix re-keyed by regime name.
    transition_matrix = {
        from_label: {
            to_label: round(float(model.transmat_[from_state, to_state]), 4)
            for to_state, to_label in label_map.items()
        }
        for from_state, from_label in label_map.items()
    }

    # Regime history over the trailing window.
    history_slice = min(history_days, len(hidden_states))
    regime_history = []
    for i in range(-history_slice, 0):
        idx = len(hidden_states) + i
        date_val = dates[idx]
        date_str = (
            date_val.strftime("%Y-%m-%d")
            if hasattr(date_val, "strftime")
            else str(date_val)
        )
        state = int(hidden_states[idx])
        regime_history.append({
            "date": date_str,
            "regime": label_map.get(state, "unknown"),
            "probabilities": {
                label_map.get(j, "unknown"): round(float(state_probs[idx][j]), 4)
                for j in range(n_states)
            },
        })

    # Per-regime return statistics. The observation matrix is shorter than
    # the raw return series (rolling warm-up rows dropped), so align tails.
    regime_stats = {}
    close = df_latest["Close"].values
    log_returns = np.log(close[1:] / close[:-1])
    align_len = min(len(log_returns), len(hidden_states))
    aligned_returns = log_returns[-align_len:]
    aligned_states = hidden_states[-align_len:]
    for state_idx, label in label_map.items():
        mask = aligned_states == state_idx
        if mask.sum() > 0:
            state_returns = aligned_returns[mask]
            regime_stats[label] = {
                "mean_daily_return_pct": round(float(np.mean(state_returns) * 100), 4),
                "daily_volatility_pct": round(float(np.std(state_returns) * 100), 4),
                "days_in_regime": int(mask.sum()),
                "pct_of_time": round(float(mask.mean() * 100), 2),
            }

    return {
        "ticker": ticker,
        "current_regime": current_regime,
        "regime_info": REGIME_LABELS.get(current_regime, {}),
        "regime_probabilities": regime_probabilities,
        "transition_matrix": transition_matrix,
        "regime_history": regime_history,
        "regime_stats": regime_stats,
        "from_cache": used_cache,
        "total_observations": len(observations),
    }
def clear_cache(ticker: Optional[str] = None) -> int:
    """
    Clear the HMM model cache.

    Args:
        ticker: if given, clear only that ticker's entries; otherwise all.

    Returns:
        Number of entries cleared.
    """
    if ticker:
        # BUG FIX: match the full "{ticker}_" key prefix. Cache keys have
        # the form "{ticker}_{period}_{n_states}", so a bare
        # startswith(ticker) would also clear e.g. "AAPL_..." entries
        # when asked to clear ticker "AA".
        prefix = f"{ticker}_"
        keys = [k for k in _hmm_cache if k.startswith(prefix)]
        for k in keys:
            del _hmm_cache[k]
        return len(keys)
    count = len(_hmm_cache)
    _hmm_cache.clear()
    return count