| """ |
| Hamiltonian Monte Carlo (NUTS) for complex pattern discovery. |
| Uses Pyro with NUTS and provides posterior summaries. |
| """ |
| import logging |
| import pyro |
| import pyro.distributions as dist |
| from pyro.infer import MCMC, NUTS |
| import torch |
| import numpy as np |
| import pandas as pd |
| from typing import Dict, Any, Optional |
|
|
# Module-level logger, named after this module per the standard `logging` convention.
logger = logging.getLogger(__name__)
|
|
class HMCAnalyzer:
    """Run NUTS (Hamiltonian Monte Carlo) on a simple Bayesian linear regression.

    The model is ``y ~ Normal(alpha + beta * x, sigma)`` with priors
    ``alpha ~ Normal(0, 10)``, ``beta ~ Normal(0, 1)``, ``sigma ~ HalfNormal(1)``.

    Attributes:
        mcmc: the ``pyro.infer.MCMC`` runner from the latest ``run_inference``
            call, or None if no run has started (or the latest one failed early).
        trace: dict of posterior draws keyed by sample-site name, as returned by
            ``MCMC.get_samples()``, or None if no run has completed.
    """

    def __init__(self):
        self.mcmc = None
        self.trace = None

    def _model(self, x, y=None):
        """Pyro model for the linear regression.

        Args:
            x: 1-D tensor of predictor values.
            y: tensor of observed responses; when None the "obs" site is
                sampled instead of conditioned on.
        """
        alpha = pyro.sample("alpha", dist.Normal(0, 10))
        beta = pyro.sample("beta", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.HalfNormal(1))
        mu = alpha + beta * x
        # Observations are conditionally independent given the parameters.
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    @staticmethod
    def _synthetic_data():
        """Return ``(x, y)``: 50 points of ``y = 2.0 - 0.3 * x + Normal(0, 0.5)`` on x in [0, 10].

        Noise is drawn from torch's global RNG.
        """
        x = torch.linspace(0, 10, 50)
        true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
        y = true_alpha + true_beta * x + torch.randn(50) * true_sigma
        return x, y

    @staticmethod
    def _prepare_data(data):
        """Convert the ``'x'`` and ``'y'`` columns of *data* into float32 tensors.

        Returns:
            Tuple ``(x, y)`` of float32 tensors.

        Raises:
            KeyError: if *data* has no ``'x'`` or ``'y'`` column.
            ValueError: if *data* has no rows, or either column holds NaN/inf
                (which would otherwise make every NUTS log-density evaluation invalid).
        """
        # to_numpy(dtype=...) also handles pandas extension dtypes that `.values`
        # would hand to torch as non-numpy arrays.
        x = torch.tensor(data['x'].to_numpy(dtype=np.float32), dtype=torch.float32)
        y = torch.tensor(data['y'].to_numpy(dtype=np.float32), dtype=torch.float32)
        if x.numel() == 0:
            raise ValueError("data must contain at least one row")
        if not (bool(torch.isfinite(x).all()) and bool(torch.isfinite(y).all())):
            raise ValueError("columns 'x' and 'y' must contain only finite values")
        return x, y

    def run_inference(
        self,
        data: Optional[pd.DataFrame] = None,
        num_samples: int = 500,
        warmup: int = 200,
        seed: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run NUTS on user-supplied or synthetic data and summarize the posterior.

        Args:
            data: DataFrame with numeric ``'x'`` and ``'y'`` columns. When None,
                50 synthetic points with a known linear trend are generated.
            num_samples: number of posterior draws kept after warmup.
            warmup: number of warmup (adaptation) steps.
            seed: when given, passed to ``pyro.set_rng_seed``, which seeds the
                global torch/numpy/random RNGs so the synthetic data and the chain
                are reproducible. When None, RNG state is left untouched.

        Returns:
            The same dict as ``_summary()`` for the new run.

        Raises:
            KeyError: if *data* lacks an ``'x'`` or ``'y'`` column.
            ValueError: if *data* is empty or contains non-finite values.
        """
        # Clear results first so a failed run never leaves the previous run's
        # samples paired with a new (or missing) MCMC object.
        self.mcmc = None
        self.trace = None

        if seed is not None:
            pyro.set_rng_seed(seed)

        if data is None:
            x, y = self._synthetic_data()
        else:
            x, y = self._prepare_data(data)

        nuts_kernel = NUTS(self._model)
        self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup)
        self.mcmc.run(x, y)

        self.trace = self.mcmc.get_samples()
        return self._summary()

    def _summary(self) -> Dict[str, Any]:
        """Summarize the posterior draws of ``alpha``, ``beta`` and ``sigma``.

        Returns:
            ``{}`` if no inference has completed. Otherwise, for each parameter, a
            dict with ``'mean'``, ``'std'`` (population std, ddof=0), ``'hpd_5'``
            and ``'hpd_95'``. Despite their names, the last two are the 5th and
            95th percentiles, i.e. a 90% equal-tailed credible interval rather than
            a highest-posterior-density interval. The key names are kept for
            backward compatibility.
        """
        if self.trace is None:
            return {}
        summary = {}
        for key in ('alpha', 'beta', 'sigma'):
            # detach/cpu so the conversion also works for grad-tracking tensors
            # (and, presumably, for tensors not on the CPU).
            samples = self.trace[key].detach().cpu().numpy()
            lower, upper = np.percentile(samples, [5, 95])
            summary[key] = {
                'mean': float(samples.mean()),
                'std': float(samples.std()),
                'hpd_5': float(lower),
                'hpd_95': float(upper),
            }
        return summary

    def get_trace_data(self) -> Dict[str, np.ndarray]:
        """Return every recorded posterior site as a NumPy array, e.g. for plotting.

        Returns ``{}`` if no inference has completed.
        """
        if self.trace is None:
            return {}
        return {name: draws.detach().cpu().numpy() for name, draws in self.trace.items()}