import json
import os

import numpy as np
import pandas as pd

from config import logger, OUTPUT_DIR, Color


class NormalInverseWishartUpdater:
    """
    Maintains a Normal-Inverse-Wishart (NIW) prior over the expected returns and covariance.
    Allows for online, incremental updates of beliefs without full ML model retraining.

    Parameters held on the instance (d = number of tickers):
        mu     -- location vector (posterior mean of expected returns), shape (d,)
        kappa  -- pseudo-observation count backing ``mu``
        nu     -- degrees of freedom of the Inverse-Wishart component
        Lambda -- Inverse-Wishart scale matrix, shape (d, d)

    The expected covariance reported by ``get_posterior`` is
    ``Lambda / max(1, nu - d - 1)``.
    """

    # Cap on kappa (effective sample size), ~ one year of trading days.
    # Class-level default so instances built without __init__ still work.
    max_memory = 252.0

    def __init__(self, tickers, initial_mu=None, initial_cov=None, kappa_0=10.0, nu_0=None,
                 max_memory=252.0):
        """
        Build the prior.

        Args:
            tickers: Sequence of asset labels; its length defines the dimension d.
            initial_mu: Optional prior mean of returns (length d). Defaults to zeros.
            initial_cov: Optional prior covariance (d x d). Defaults to the identity.
                Lambda is scaled so that ``get_posterior()`` returns exactly this
                matrix before any observation is absorbed.
            kappa_0: Prior confidence in ``initial_mu``, in pseudo-observations.
            nu_0: Prior degrees of freedom. Defaults to d + 2, the smallest integer
                for which the Inverse-Wishart mean exists (requires nu > d + 1).
            max_memory: Upper bound on kappa; see ``update``.
        """
        self.tickers = tickers
        self.d = len(tickers)
        self.kappa = float(kappa_0)
        self.nu = float(nu_0 if nu_0 is not None else self.d + 2)
        self.max_memory = float(max_memory)

        if initial_mu is not None:
            # np.array copies, and also accepts lists / pandas Series.
            self.mu = np.array(initial_mu, dtype=float)
        else:
            self.mu = np.zeros(self.d)

        # Scale Lambda by the same denominator get_posterior divides by, so the
        # initial posterior covariance equals initial_cov (not nu/(nu-d-1) times it).
        denom = self._cov_denominator()
        if initial_cov is not None:
            self.Lambda = np.array(initial_cov, dtype=float) * denom
        else:
            self.Lambda = np.eye(self.d) * denom

        self.state_path = os.path.join(OUTPUT_DIR, "niw_prior_state.json")

    def _cov_denominator(self):
        """Return the Inverse-Wishart mean denominator nu - d - 1, floored at 1."""
        return max(1.0, self.nu - self.d - 1)

    def save_state(self):
        """
        Persist tickers, kappa, nu, mu and Lambda as JSON at ``self.state_path``.

        The payload is serialised before touching disk and written to a temporary
        sibling file that atomically replaces the target, so an interrupted write
        cannot leave a truncated state file behind.
        """
        state = {
            'tickers': list(self.tickers),
            'kappa': float(self.kappa),
            'nu': float(self.nu),
            'mu': np.asarray(self.mu, dtype=float).tolist(),
            'Lambda': np.asarray(self.Lambda, dtype=float).tolist()
        }
        payload = json.dumps(state)

        directory = os.path.dirname(self.state_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = self.state_path + ".tmp"
        with open(tmp_path, 'w') as f:
            f.write(payload)
        os.replace(tmp_path, self.state_path)

    def load_state(self):
        """
        Restore parameters previously written by ``save_state``.

        Returns:
            True if a state file existed, matched ``self.tickers`` and had
            consistent shapes; False otherwise. On any failure the current
            parameters are left untouched (nothing is partially overwritten).
        """
        if not os.path.exists(self.state_path):
            return False
        try:
            with open(self.state_path, 'r') as f:
                state = json.load(f)
            # JSON always yields a list, so compare against a list view of tickers
            # (tuples / pandas Index would otherwise never compare equal).
            if state['tickers'] != list(self.tickers):
                return False
            kappa = float(state['kappa'])
            nu = float(state['nu'])
            mu = np.array(state['mu'], dtype=float)
            Lambda = np.array(state['Lambda'], dtype=float)
        except Exception as e:
            # Best-effort load: a corrupt/unreadable file falls back to the fresh prior.
            logger.warning(f"Failed to load NIW state: {e}")
            return False

        if mu.shape != (self.d,) or Lambda.shape != (self.d, self.d):
            logger.warning(
                f"Failed to load NIW state: shape mismatch "
                f"(mu {mu.shape}, Lambda {Lambda.shape}, expected d={self.d})"
            )
            return False

        # Commit all parameters together only after everything parsed cleanly.
        self.kappa = kappa
        self.nu = nu
        self.mu = mu
        self.Lambda = Lambda
        return True

    def update(self, x_new):
        """
        Updates the NIW parameters given a new observation x_new (1D array of returns).

        Args:
            x_new: One observation per ticker, length d.

        Raises:
            ValueError: If x_new is not a 1-D vector of length d, or contains
                NaN/inf (which would otherwise permanently poison mu and Lambda).
        """
        x = np.asarray(x_new, dtype=float)
        if x.ndim != 1 or x.shape[0] != self.d:
            raise ValueError("Observation dimension mismatch.")
        if not np.all(np.isfinite(x)):
            raise ValueError("Observation contains non-finite values.")

        # Conjugate update formulas for a single observation (n=1).
        kappa_n = self.kappa + 1.0
        nu_n = self.nu + 1.0
        diff = x - self.mu
        mu_n = (self.kappa * self.mu + x) / kappa_n
        # Rank-1 update to the scale matrix.
        Lambda_n = self.Lambda + (self.kappa / kappa_n) * np.outer(diff, diff)

        # To prevent kappa and nu from growing without bound (which would freeze the
        # prior), rescale all prior weights once kappa exceeds max_memory. Once capped,
        # every update multiplies the accumulated weights by max_memory / (max_memory + 1),
        # i.e. exponential forgetting with a ~max_memory-observation window.
        # NOTE(review): under this rescaling nu converges to max_memory regardless of d;
        # if d is not well below max_memory, nu - d - 1 becomes small (or is floored at 1
        # in get_posterior), inflating the covariance. Use a larger max_memory for large d.
        max_memory = self.max_memory
        if kappa_n > max_memory:
            decay = max_memory / kappa_n
            kappa_n *= decay
            nu_n *= decay
            Lambda_n *= decay

        self.kappa = kappa_n
        self.nu = nu_n
        self.mu = mu_n
        self.Lambda = Lambda_n

    def get_posterior(self):
        """
        Returns posterior expected returns and covariance matrix.

        The covariance is the Inverse-Wishart mean Lambda / (nu - d - 1), with the
        denominator floored at 1 for numerical safety. Both returned arrays are
        fresh copies, so callers may modify them without affecting this object.
        """
        cov_posterior = self.Lambda / self._cov_denominator()
        return self.mu.copy(), cov_posterior

    def compute_divergence(self, new_mu, new_cov):
        """
        Computes the Mahalanobis distance between the new ML predictions and the NIW prior.

        The distance is measured from the prior mean ``self.mu`` to ``new_mu`` under
        the metric of ``new_cov``; a singular ``new_cov`` falls back to its
        pseudo-inverse.

        Returns:
            The non-negative distance as a Python float.
        """
        new_cov = np.asarray(new_cov, dtype=float)
        try:
            cov_inv = np.linalg.inv(new_cov)
        except np.linalg.LinAlgError:
            cov_inv = np.linalg.pinv(new_cov)
        diff = np.asarray(new_mu, dtype=float) - self.mu
        quad = float(np.dot(diff, np.dot(cov_inv, diff)))
        # Round-off can push the quadratic form marginally below zero; clamp so the
        # result is never NaN.
        return float(np.sqrt(max(quad, 0.0)))