# advanced_inference.py
"""
|
| 2 |
-
Hamiltonian Monte Carlo (NUTS) for complex pattern discovery.
|
| 3 |
-
Uses Pyro with NUTS and provides posterior summaries.
|
| 4 |
-
"""
|
| 5 |
-
import logging
|
| 6 |
-
import pyro
|
| 7 |
-
import pyro.distributions as dist
|
| 8 |
-
from pyro.infer import MCMC, NUTS
|
| 9 |
-
import torch
|
| 10 |
-
import numpy as np
|
| 11 |
-
import pandas as pd
|
| 12 |
-
from typing import Dict, Any, Optional
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
class HMCAnalyzer:
    """Runs HMC on a simple regression model to demonstrate advanced inference.

    The model is a Bayesian linear regression, y ~ Normal(alpha + beta * x, sigma),
    sampled with Pyro's NUTS kernel via its MCMC driver.
    """

    def __init__(self):
        # Both populated by run_inference(); None until inference has been run.
        self.mcmc = None
        self.trace = None

    def _model(self, x, y=None):
        """Pyro model: linear regression with unknown intercept, slope and noise.

        Args:
            x: 1-D tensor of predictor values.
            y: optional 1-D tensor of observations; None leaves "obs" unconditioned.
        """
        alpha = pyro.sample("alpha", dist.Normal(0, 10))
        beta = pyro.sample("beta", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.HalfNormal(1))
        mu = alpha + beta * x
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    def run_inference(self, data: Optional[pd.DataFrame] = None, num_samples: int = 500, warmup: int = 200) -> Dict[str, Any]:
        """Run HMC (NUTS) on synthetic or provided data.

        Args:
            data: optional DataFrame with numeric columns 'x' and 'y'.
                If None, synthetic linear-trend data with Gaussian noise is used.
            num_samples: number of posterior samples to draw after warmup.
            warmup: number of warmup (adaptation) steps.

        Returns:
            Posterior summary statistics, as produced by _summary().

        Raises:
            ValueError: if `data` is provided but is empty or lacks the
                required 'x'/'y' columns.
        """
        if data is None:
            # Synthetic data: a linear trend with noise.
            x = torch.linspace(0, 10, 50)
            true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
            y = true_alpha + true_beta * x + torch.randn(50) * true_sigma
        else:
            # Validate up front so callers get a clear error instead of a
            # bare KeyError from the column lookup.
            missing = {'x', 'y'} - set(data.columns)
            if missing:
                raise ValueError(f"data is missing required column(s): {sorted(missing)}")
            if len(data) == 0:
                raise ValueError("data must contain at least one row")
            x = torch.tensor(data['x'].values, dtype=torch.float32)
            y = torch.tensor(data['y'].values, dtype=torch.float32)

        nuts_kernel = NUTS(self._model)
        self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup)
        self.mcmc.run(x, y)

        self.trace = self.mcmc.get_samples()
        return self._summary()

    def _summary(self) -> Dict[str, Any]:
        """Return summary statistics of posterior samples, keyed by site name.

        NOTE: the 'hpd_5'/'hpd_95' keys are kept for backward compatibility,
        but they are 5th/95th percentiles (an equal-tailed 90% interval),
        not a true highest-posterior-density interval.
        """
        if self.trace is None:
            return {}
        summary = {}
        # Iterate over whatever sites the trace actually contains instead of a
        # hard-coded ['alpha', 'beta', 'sigma'] list, so the summary neither
        # crashes nor silently drops sites if the model changes.
        for key, tensor in self.trace.items():
            # detach/cpu make this safe for grad-tracking or GPU tensors.
            samples = tensor.detach().cpu().numpy()
            summary[key] = {
                'mean': float(samples.mean()),
                'std': float(samples.std()),
                'hpd_5': float(np.percentile(samples, 5)),
                'hpd_95': float(np.percentile(samples, 95))
            }
        return summary

    def get_trace_data(self) -> Dict[str, np.ndarray]:
        """Return posterior samples (as numpy arrays) for plotting."""
        if self.trace is None:
            return {}
        # detach/cpu for the same reason as in _summary().
        return {k: v.detach().cpu().numpy() for k, v in self.trace.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|