Spaces:
Runtime error
Runtime error
Create advanced_inference.py
Browse files- advanced_inference.py +73 -0
advanced_inference.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hamiltonian Monte Carlo (NUTS) for complex pattern discovery.
|
| 3 |
+
Uses Pyro's NUTS kernel; posterior samples can be visualized externally (e.g. with ArviZ).
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import pyro
|
| 7 |
+
import pyro.distributions as dist
|
| 8 |
+
from pyro.infer import MCMC, NUTS
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from typing import Dict, Any, Optional
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class HMCAnalyzer:
    """Run Hamiltonian Monte Carlo (NUTS) on a simple linear-regression model.

    Demonstrates advanced inference: posterior samples are stored in
    ``self.trace`` (a dict mapping latent-site names to tensors) and
    summarized with percentile-based credible intervals.
    """

    def __init__(self):
        # Both populated by run_inference(); None until inference has run.
        self.mcmc = None
        self.trace = None

    def _model(self, x, y=None):
        """Pyro model: linear regression with unknown Gaussian noise.

        y ~ Normal(alpha + beta * x, sigma), with weakly-informative priors
        on the intercept, slope, and noise scale.
        """
        alpha = pyro.sample("alpha", dist.Normal(0, 10))
        beta = pyro.sample("beta", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.HalfNormal(1))
        mu = alpha + beta * x
        # Plate marks the observations as conditionally independent.
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    def run_inference(self, data: Optional[pd.DataFrame] = None, num_samples: int = 500, warmup: int = 200) -> Dict[str, Any]:
        """
        Run NUTS on the provided data, or on synthetic trend data if none is given.

        Args:
            data: optional DataFrame with numeric 'x' and 'y' columns.
            num_samples: posterior samples to draw after warmup.
            warmup: number of warmup (adaptation) steps.

        Returns:
            Posterior summary dict as produced by ``_summary()``.

        Raises:
            ValueError: if ``data`` is given but lacks an 'x' or 'y' column.
        """
        if data is None:
            # Synthetic linear trend with Gaussian noise, so the sampler has
            # a known ground truth (alpha=2.0, beta=-0.3, sigma=0.5).
            # NOTE(review): no seed is set, so synthetic runs are not reproducible.
            x = torch.linspace(0, 10, 50)
            true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
            y = true_alpha + true_beta * x + torch.randn(50) * true_sigma
        else:
            # Fail early with a clear message instead of an opaque KeyError.
            missing = {'x', 'y'} - set(data.columns)
            if missing:
                raise ValueError(f"data is missing required column(s): {sorted(missing)}")
            x = torch.tensor(data['x'].values, dtype=torch.float32)
            y = torch.tensor(data['y'].values, dtype=torch.float32)

        nuts_kernel = NUTS(self._model)
        self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup)
        self.mcmc.run(x, y)

        self.trace = self.mcmc.get_samples()
        return self._summary()

    def _summary(self) -> Dict[str, Any]:
        """Return mean/std and a 5th–95th percentile interval per latent site.

        NOTE: despite the 'hpd_*' key names (kept for backward compatibility),
        these are equal-tailed percentile bounds, not true highest-posterior-
        density bounds.
        """
        if self.trace is None:
            return {}
        summary: Dict[str, Any] = {}
        for key in ['alpha', 'beta', 'sigma']:
            # detach().cpu() makes conversion safe for tensors that carry
            # grad or live on an accelerator; bare .numpy() would raise there.
            samples = self.trace[key].detach().cpu().numpy()
            summary[key] = {
                'mean': float(samples.mean()),
                'std': float(samples.std()),
                'hpd_5': float(np.percentile(samples, 5)),
                'hpd_95': float(np.percentile(samples, 95))
            }
        return summary

    def get_trace_data(self) -> Dict[str, np.ndarray]:
        """Return posterior samples as numpy arrays (e.g. for plotting)."""
        if self.trace is None:
            return {}
        # Same detach/cpu safety as in _summary().
        return {k: v.detach().cpu().numpy() for k, v in self.trace.items()}