# quanthedge/backend/app/services/ml/hmm_regime.py
# (QuantHedge: full deployment with Docker + nginx + uvicorn — commit 9d29748)
"""
Hidden Markov Model Regime Detector.
Fits a Gaussian HMM with 3 hidden states (bull, bear, high-volatility)
on daily returns and realized volatility. Classifies the current market
regime and outputs state probabilities, transition matrix, and history.
Design decisions:
- In-memory model cache with 6-hour TTL β€” auto-retrains on fresh data
- No disk persistence β€” HuggingFace / serverless friendly
- 3 regimes mapped by sorting on mean return (highest=bull, lowest=bear,
middle=high-volatility if vol is highest, else sideways)
- Uses log returns + rolling vol as observation features
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)  # module-scoped logger, named after this module
# ── Cache ────────────────────────────────────────────────────────────────
# Maximum age of a cached model before CachedHMM.is_stale forces a retrain.
CACHE_TTL_SECONDS: int = 6 * 3600 # 6 hours
@dataclass
class CachedHMM:
    """In-memory cache entry: a fitted HMM plus its label map and train time."""
    model: Any  # fitted HMM instance (opaque here; used via predict/predict_proba)
    label_map: Dict[int, str]  # hidden-state index -> regime name
    trained_at: float = field(default_factory=time.time)  # epoch seconds at fit time

    @property
    def is_stale(self) -> bool:
        """True once this entry is older than CACHE_TTL_SECONDS."""
        age_seconds = time.time() - self.trained_at
        return age_seconds > CACHE_TTL_SECONDS
_hmm_cache: Dict[str, CachedHMM] = {}  # keyed "{ticker}_{period}_{n_states}"; staleness via CachedHMM.is_stale
# ── Regime Labels ────────────────────────────────────────────────────────
REGIME_LABELS = {
"bull": {"emoji": "πŸ‚", "color": "#22c55e", "description": "Uptrend β€” positive returns, moderate volatility"},
"bear": {"emoji": "🐻", "color": "#ef4444", "description": "Downtrend β€” negative returns, rising volatility"},
"high_volatility": {"emoji": "⚑", "color": "#f59e0b", "description": "Choppy β€” mixed returns, elevated volatility"},
}
# ── Feature Construction ─────────────────────────────────────────────────
def _build_observations(df: pd.DataFrame) -> Tuple[np.ndarray, pd.DatetimeIndex]:
"""
Build observation matrix for HMM from OHLCV data.
Features:
1. Log return (daily)
2. 5-day rolling volatility (annualised)
3. 10-day rolling volatility (annualised)
"""
close = df["Close"]
log_ret = np.log(close / close.shift(1))
vol_5 = log_ret.rolling(5).std() * np.sqrt(252)
vol_10 = log_ret.rolling(10).std() * np.sqrt(252)
obs_df = pd.DataFrame({
"log_return": log_ret,
"vol_5d": vol_5,
"vol_10d": vol_10,
}).dropna()
return obs_df.values, obs_df.index
def _map_states_to_regimes(model: Any, n_states: int = 3) -> Dict[int, str]:
    """
    Assign regime names to the HMM's hidden states.

    States are ranked by mean log return (first observation feature):
    lowest mean -> "bear", highest -> "bull", and the middle state ->
    "high_volatility" (empirically the widest-variance state).

    Args:
        model: fitted GaussianHMM exposing ``means_``.
        n_states: number of hidden states (used only for the log line).

    Returns:
        Mapping of hidden-state index -> regime name.
    """
    state_means = model.means_[:, 0]  # per-state mean log return
    ranked = np.argsort(state_means)  # indices from lowest to highest mean
    label_map = {
        int(ranked[0]): "bear",
        int(ranked[1]): "high_volatility",
        int(ranked[2]): "bull",
    }
    logger.info(
        "HMM regime mapping: %s (means: %s)",
        label_map,
        {int(i): round(float(state_means[i]), 6) for i in range(n_states)},
    )
    return label_map
# ── Training ─────────────────────────────────────────────────────────────
def _train_hmm(
    df: pd.DataFrame,
    n_states: int = 3,
    n_iter: int = 100,
) -> Tuple[Any, Dict[int, str]]:
    """
    Fit a GaussianHMM on historical price data.

    Args:
        df: price history with a "Close" column.
        n_states: number of hidden regimes to fit.
        n_iter: maximum EM iterations.

    Returns:
        (fitted_model, state_label_map).

    Raises:
        ValueError: if fewer than 60 usable observations remain after
            feature construction.
    """
    from hmmlearn.hmm import GaussianHMM
    obs, _dates = _build_observations(df)
    n_obs = len(obs)
    if n_obs < 60:
        raise ValueError(f"Insufficient data: {n_obs} obs (need 60+)")
    hmm = GaussianHMM(
        n_components=n_states,
        covariance_type="full",
        n_iter=n_iter,
        random_state=42,  # deterministic initialisation across retrains
        tol=0.01,
    )
    hmm.fit(obs)
    label_map = _map_states_to_regimes(hmm, n_states)
    logger.info(
        "HMM trained: %d observations, %d states, converged=%s",
        n_obs, n_states, hmm.monitor_.converged,
    )
    return hmm, label_map
# ── Prediction ───────────────────────────────────────────────────────────
async def detect_regime(
    ticker: str,
    period: str = "2y",
    n_states: int = 3,
    history_days: int = 60,
) -> Dict[str, Any]:
    """
    Detect the current market regime for a ticker.

    Uses the cached model when younger than CACHE_TTL_SECONDS; otherwise
    fetches fresh price data and retrains. Fix over the previous version:
    the retrain path fetched the same price data twice (once for training,
    once for scoring) — the freshly fetched frame is now reused for scoring,
    halving data-source calls on that path. The cache-hit path still fetches
    current data, which may be newer than what the cached model saw.

    Args:
        ticker: symbol understood by the Yahoo adapter.
        period: lookback window passed to the data source (e.g. "2y").
        n_states: number of hidden regimes (label mapping assumes 3).
        history_days: trailing days included in regime_history.

    Returns a dict with:
        - current_regime: "bull", "bear", or "high_volatility"
        - regime_info: emoji/colour/description for the current regime
        - regime_probabilities: probability of each regime today
        - transition_matrix: NxN transition probabilities, regime-named
        - regime_history: last N days with regime labels + probabilities
        - regime_stats: per-regime return/volatility statistics
        - from_cache / total_observations: diagnostics

    Raises:
        ValueError: on insufficient price or observation data.
    """
    from app.services.data_ingestion.yahoo import yahoo_adapter
    cache_key = f"{ticker}_{period}_{n_states}"
    cached = _hmm_cache.get(cache_key)
    used_cache = False
    df_latest: Optional[pd.DataFrame] = None
    if cached and not cached.is_stale:
        model = cached.model
        label_map = cached.label_map
        used_cache = True
        logger.info("Using cached HMM for %s (age: %.0fs)", ticker,
                    time.time() - cached.trained_at)
    else:
        # Fetch fresh data and train
        df = await yahoo_adapter.get_price_dataframe(ticker, period=period)
        if df.empty or len(df) < 60:
            raise ValueError(f"Insufficient price data for {ticker}")
        model, label_map = _train_hmm(df, n_states=n_states)
        _hmm_cache[cache_key] = CachedHMM(
            model=model,
            label_map=label_map,
        )
        df_latest = df  # just fetched — reuse for scoring, no second call
    if df_latest is None:
        # Cache hit: fetch current data for scoring (may be newer than the
        # data the cached model was trained on).
        df_latest = await yahoo_adapter.get_price_dataframe(ticker, period=period)
    observations, dates = _build_observations(df_latest)
    if len(observations) < 10:
        raise ValueError(f"Insufficient observation data for {ticker}")
    # Viterbi decode: most likely hidden-state path over the whole sample.
    hidden_states = model.predict(observations)
    # Posterior state probabilities per day; last row is "today".
    state_probs = model.predict_proba(observations)
    current_probs = state_probs[-1]
    current_state = int(hidden_states[-1])
    current_regime = label_map.get(current_state, "unknown")
    # Probability of each named regime today.
    regime_probabilities = {
        label: round(float(current_probs[state_idx]), 4)
        for state_idx, label in label_map.items()
    }
    # Transition matrix re-keyed by regime names.
    transition_matrix = {
        from_label: {
            to_label: round(float(model.transmat_[from_state, to_state]), 4)
            for to_state, to_label in label_map.items()
        }
        for from_state, from_label in label_map.items()
    }
    # Trailing regime history (most recent `history_days` rows).
    history_slice = min(history_days, len(hidden_states))
    regime_history = []
    for i in range(-history_slice, 0):
        idx = len(hidden_states) + i
        date_str = (
            dates[idx].strftime("%Y-%m-%d")
            if hasattr(dates[idx], "strftime")
            else str(dates[idx])
        )
        state = int(hidden_states[idx])
        regime_history.append({
            "date": date_str,
            "regime": label_map.get(state, "unknown"),
            "probabilities": {
                label_map.get(j, "unknown"): round(float(state_probs[idx][j]), 4)
                for j in range(n_states)
            },
        })
    # Per-regime return statistics. Observations are shorter than the raw
    # return series (rolling-window NaNs dropped), so align from the end:
    # both series terminate on the most recent day.
    regime_stats = {}
    close = df_latest["Close"].values
    log_returns = np.log(close[1:] / close[:-1])
    align_len = min(len(log_returns), len(hidden_states))
    aligned_returns = log_returns[-align_len:]
    aligned_states = hidden_states[-align_len:]
    for state_idx, label in label_map.items():
        mask = aligned_states == state_idx
        if mask.sum() > 0:
            state_returns = aligned_returns[mask]
            regime_stats[label] = {
                "mean_daily_return_pct": round(float(np.mean(state_returns) * 100), 4),
                "daily_volatility_pct": round(float(np.std(state_returns) * 100), 4),
                "days_in_regime": int(mask.sum()),
                "pct_of_time": round(float(mask.mean() * 100), 2),
            }
    return {
        "ticker": ticker,
        "current_regime": current_regime,
        "regime_info": REGIME_LABELS.get(current_regime, {}),
        "regime_probabilities": regime_probabilities,
        "transition_matrix": transition_matrix,
        "regime_history": regime_history,
        "regime_stats": regime_stats,
        "from_cache": used_cache,
        "total_observations": len(observations),
    }
def clear_cache(ticker: Optional[str] = None) -> int:
    """
    Clear the in-memory HMM cache.

    Args:
        ticker: when given, remove only that ticker's entries; when None,
            clear everything.

    Returns:
        Number of cache entries removed.
    """
    if ticker:
        # Cache keys are "{ticker}_{period}_{n_states}". Match on the full
        # "{ticker}_" prefix so that clearing "A" does not also clear
        # "AAPL_..." entries (the old bare startswith(ticker) over-matched).
        prefix = f"{ticker}_"
        keys = [k for k in _hmm_cache if k.startswith(prefix)]
        for k in keys:
            del _hmm_cache[k]
        return len(keys)
    count = len(_hmm_cache)
    _hmm_cache.clear()
    return count