""" Hamiltonian Monte Carlo (NUTS) for complex pattern discovery. Uses Pyro with NUTS and provides posterior summaries. """ import logging import pyro import pyro.distributions as dist from pyro.infer import MCMC, NUTS import torch import numpy as np import pandas as pd from typing import Dict, Any, Optional logger = logging.getLogger(__name__) class HMCAnalyzer: """Runs HMC on a simple regression model to demonstrate advanced inference.""" def __init__(self): self.mcmc = None self.trace = None def _model(self, x, y=None): # Linear regression with unknown noise alpha = pyro.sample("alpha", dist.Normal(0, 10)) beta = pyro.sample("beta", dist.Normal(0, 1)) sigma = pyro.sample("sigma", dist.HalfNormal(1)) mu = alpha + beta * x with pyro.plate("data", len(x)): pyro.sample("obs", dist.Normal(mu, sigma), obs=y) def run_inference(self, data: Optional[pd.DataFrame] = None, num_samples: int = 500, warmup: int = 200): """ Run HMC on synthetic or provided data. If no data, generate synthetic trend data. """ if data is None: # Create synthetic data: a linear trend with noise x = torch.linspace(0, 10, 50) true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5 y = true_alpha + true_beta * x + torch.randn(50) * true_sigma else: # Assume data has columns 'x' and 'y' x = torch.tensor(data['x'].values, dtype=torch.float32) y = torch.tensor(data['y'].values, dtype=torch.float32) nuts_kernel = NUTS(self._model) self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup) self.mcmc.run(x, y) self.trace = self.mcmc.get_samples() return self._summary() def _summary(self) -> Dict[str, Any]: """Return summary statistics of posterior samples.""" if self.trace is None: return {} summary = {} for key in ['alpha', 'beta', 'sigma']: samples = self.trace[key].numpy() summary[key] = { 'mean': float(samples.mean()), 'std': float(samples.std()), 'hpd_5': float(np.percentile(samples, 5)), 'hpd_95': float(np.percentile(samples, 95)) } return summary def get_trace_data(self) -> Dict[str, np.ndarray]: """Return posterior samples for plotting.""" if self.trace is None: return {} return {k: v.numpy() for k, v in self.trace.items()}