# Agentic-Reliability-Framework-v4 / advanced_inference.py
# (Header copied from the Hugging Face file viewer: last commit
#  "Update advanced_inference.py" (cce889a, verified) by petter2025; 2.61 kB.)
"""
Hamiltonian Monte Carlo (NUTS) for complex pattern discovery.
Uses Pyro with NUTS and provides posterior summaries.
"""
import logging
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import torch
import numpy as np
import pandas as pd
from typing import Dict, Any, Optional
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
class HMCAnalyzer:
    """Bayesian linear regression sampled with Pyro's NUTS (HMC) kernel.

    Model: ``obs ~ Normal(alpha + beta * x, sigma)`` with priors
    ``alpha ~ Normal(0, 10)``, ``beta ~ Normal(0, 1)``, ``sigma ~ HalfNormal(1)``.
    Call :meth:`run_inference` to sample the posterior; it returns summary
    statistics, and :meth:`get_trace_data` exposes the raw draws as NumPy arrays.
    """

    # Latent sites summarised by _summary (the sites sampled in _model).
    _PARAMS = ("alpha", "beta", "sigma")

    def __init__(self):
        # MCMC object from the most recent run_inference call; None before any run.
        self.mcmc: Optional[Any] = None
        # Posterior draws from MCMC.get_samples(), keyed by site name;
        # None until a run completes successfully.
        self.trace: Optional[Dict[str, torch.Tensor]] = None

    def _model(self, x, y=None):
        """Pyro model: linear regression with unknown Gaussian noise.

        Args:
            x: 1-D tensor of predictor values, one per observation.
            y: tensor of observed responses matching ``x``, or ``None`` to leave
                the ``obs`` site unobserved.
        """
        alpha = pyro.sample("alpha", dist.Normal(0, 10))
        beta = pyro.sample("beta", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.HalfNormal(1))
        mu = alpha + beta * x
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    def run_inference(
        self,
        data: Optional[pd.DataFrame] = None,
        num_samples: int = 500,
        warmup: int = 200,
        seed: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Sample the posterior with NUTS and return summary statistics.

        Args:
            data: DataFrame with numeric columns ``'x'`` and ``'y'``. If ``None``,
                a synthetic noisy linear trend (50 points on [0, 10]) is used.
            num_samples: Number of posterior draws kept after warmup.
            warmup: Number of NUTS warmup (adaptation) steps.
            seed: Optional RNG seed passed to ``pyro.set_rng_seed`` before any
                random work, making synthetic data and sampling reproducible.

        Returns:
            The dictionary produced by :meth:`_summary`.

        Raises:
            KeyError: If ``data`` lacks an ``'x'`` or ``'y'`` column.
            ValueError: If ``data`` contains NaN or infinite values.
        """
        if seed is not None:
            pyro.set_rng_seed(seed)
        x, y = self._prepare_data(data)

        # Clear the old trace before replacing self.mcmc, so a failed run cannot
        # leave samples from a previous run paired with the new sampler.
        self.trace = None
        self.mcmc = MCMC(NUTS(self._model), num_samples=num_samples, warmup_steps=warmup)
        self.mcmc.run(x, y)
        self.trace = self.mcmc.get_samples()
        return self._summary()

    @staticmethod
    def _prepare_data(data: Optional[pd.DataFrame]):
        """Return ``(x, y)`` tensors from ``data``, or synthetic ones if it is None."""
        if data is None:
            # Synthetic linear trend with Gaussian noise (uses torch's global RNG).
            x = torch.linspace(0, 10, 50)
            true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
            y = true_alpha + true_beta * x + torch.randn(50) * true_sigma
            return x, y

        missing = [col for col in ("x", "y") if col not in data.columns]
        if missing:
            raise KeyError(f"data is missing required column(s): {missing}")
        x = torch.tensor(data["x"].to_numpy(dtype=np.float32), dtype=torch.float32)
        y = torch.tensor(data["y"].to_numpy(dtype=np.float32), dtype=torch.float32)
        # Non-finite inputs would make the model's log-density NaN; fail early
        # with a clear message instead of deep inside the sampler.
        if not (torch.isfinite(x).all() and torch.isfinite(y).all()):
            raise ValueError("data columns 'x' and 'y' must contain only finite values")
        return x, y

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert a tensor to NumPy, detaching from autograd and moving to CPU first."""
        return tensor.detach().cpu().numpy()

    def _summary(self) -> Dict[str, Any]:
        """Summarise the posterior draws of alpha, beta and sigma.

        Returns:
            ``{param: {"mean", "std", "hpd_5", "hpd_95"}}`` with float values, or
            ``{}`` if no inference has completed. ``std`` is the population
            standard deviation (NumPy's default ddof=0). Despite their names,
            ``hpd_5``/``hpd_95`` are the 5th/95th percentiles, i.e. an
            equal-tailed 90% interval, not a highest-posterior-density interval.
            The key names are kept for backward compatibility.
        """
        if self.trace is None:
            return {}
        summary: Dict[str, Any] = {}
        for key in self._PARAMS:
            samples = self._to_numpy(self.trace[key])
            lower, upper = np.percentile(samples, [5, 95])
            summary[key] = {
                "mean": float(samples.mean()),
                "std": float(samples.std()),
                "hpd_5": float(lower),
                "hpd_95": float(upper),
            }
        return summary

    def get_trace_data(self) -> Dict[str, np.ndarray]:
        """Return all posterior draws as NumPy arrays (``{}`` before any run)."""
        if self.trace is None:
            return {}
        return {name: self._to_numpy(draws) for name, draws in self.trace.items()}