File size: 2,609 Bytes
332a79f
 
cce889a
332a79f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Hamiltonian Monte Carlo (NUTS) inference for a Bayesian linear regression.
Uses Pyro's NUTS sampler and provides posterior summaries.
"""
import logging
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import torch
import numpy as np
import pandas as pd
from typing import Dict, Any, Optional

# Module-level logger named after this module.
# NOTE(review): not referenced by HMCAnalyzer in this section — either emit
# progress/diagnostic messages through it or drop it.
logger = logging.getLogger(__name__)

class HMCAnalyzer:
    """Bayesian linear regression fitted with Pyro's NUTS (HMC) sampler.

    Call :meth:`run_inference` to fit the model. The posterior draws are then
    kept on the instance (``self.trace``). They can be summarised with
    :meth:`_summary` or exported as NumPy arrays with :meth:`get_trace_data`.
    """

    # Latent sites sampled by ``_model``; these are the ones ``_summary`` reports.
    _PARAMS = ("alpha", "beta", "sigma")

    def __init__(self):
        self.mcmc = None   # pyro.infer.MCMC object of the latest run, or None
        self.trace = None  # dict: site name -> tensor of posterior draws, or None

    def _model(self, x, y=None):
        """Linear regression with unknown noise: ``y ~ Normal(alpha + beta*x, sigma)``.

        Priors: ``alpha ~ Normal(0, 10)``, ``beta ~ Normal(0, 1)``,
        ``sigma ~ HalfNormal(1)``.

        Args:
            x: 1-D tensor of predictor values.
            y: Observed responses aligned with ``x``. When ``None``, the
                ``obs`` site is sampled instead of conditioned on.
        """
        alpha = pyro.sample("alpha", dist.Normal(0, 10))
        beta = pyro.sample("beta", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.HalfNormal(1))
        mu = alpha + beta * x
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    @staticmethod
    def _synthetic_data(n_points: int = 50):
        """Return ``(x, y)`` tensors from ``y = 2.0 - 0.3*x + Normal(0, 0.5)`` noise.

        ``x`` is ``n_points`` evenly spaced values on ``[0, 10]``.
        """
        x = torch.linspace(0, 10, n_points)
        true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
        y = true_alpha + true_beta * x + torch.randn(x.shape[0]) * true_sigma
        return x, y

    @staticmethod
    def _column_tensor(data: pd.DataFrame, column: str) -> torch.Tensor:
        """Return ``data[column]`` as a float32 tensor.

        Raises:
            KeyError: If ``column`` is missing.
            ValueError: If the column contains NaN or infinite values. NUTS
                would otherwise fail later with an opaque initialisation error.
        """
        values = torch.tensor(data[column].to_numpy(dtype=np.float32))
        if not bool(torch.isfinite(values).all()):
            raise ValueError(f"column {column!r} contains NaN or infinite values")
        return values

    @staticmethod
    def _to_numpy(tensor) -> np.ndarray:
        """Convert a tensor to NumPy, detaching it and moving it to the CPU first.

        A plain ``.numpy()`` raises for tensors that require grad or live on
        an accelerator.
        """
        return tensor.detach().cpu().numpy()

    def run_inference(
        self,
        data: Optional[pd.DataFrame] = None,
        num_samples: int = 500,
        warmup: int = 200,
        seed: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Fit the regression with NUTS and return posterior summaries.

        Args:
            data: DataFrame with numeric columns ``'x'`` and ``'y'``. When
                ``None``, 50 synthetic points from ``y = 2.0 - 0.3*x`` plus
                Gaussian noise (sd 0.5) are generated.
            num_samples: Number of posterior draws kept after warmup.
            warmup: Number of NUTS warmup steps.
            seed: Optional seed passed to ``pyro.set_rng_seed`` before any
                randomness is drawn. This makes synthetic data and sampling
                reproducible.

        Returns:
            The dictionary produced by :meth:`_summary` for this run.

        Raises:
            KeyError: If ``data`` lacks an ``'x'`` or ``'y'`` column.
            ValueError: If ``data`` has no rows or contains NaN/inf values.
        """
        # Forget any previous run so that a failure below cannot leave a stale
        # trace that _summary / get_trace_data would silently report.
        self.mcmc = None
        self.trace = None

        if seed is not None:
            pyro.set_rng_seed(seed)

        if data is None:
            x, y = self._synthetic_data()
        else:
            x = self._column_tensor(data, "x")
            y = self._column_tensor(data, "y")
            if len(x) == 0:
                raise ValueError("data must contain at least one row")

        nuts_kernel = NUTS(self._model)
        self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup)
        self.mcmc.run(x, y)

        self.trace = self.mcmc.get_samples()
        return self._summary()

    def _summary(self) -> Dict[str, Any]:
        """Summarise the posterior draws of ``alpha``, ``beta`` and ``sigma``.

        Returns:
            ``{}`` if no inference results are stored. Otherwise, for each
            parameter, a dict with the sample ``mean``, ``std`` (NumPy
            default, ddof=0) and the 5th/95th percentiles under
            ``hpd_5``/``hpd_95``.

        Note:
            Despite their names, ``hpd_5``/``hpd_95`` are the bounds of the
            equal-tailed 90% credible interval (plain percentiles), not of a
            highest-posterior-density interval. The key names are kept for
            backward compatibility.
        """
        if self.trace is None:
            return {}
        summary = {}
        for key in self._PARAMS:
            samples = self._to_numpy(self.trace[key])
            lower, upper = np.percentile(samples, [5, 95])
            summary[key] = {
                "mean": float(samples.mean()),
                "std": float(samples.std()),
                "hpd_5": float(lower),
                "hpd_95": float(upper),
            }
        return summary

    def get_trace_data(self) -> Dict[str, np.ndarray]:
        """Return every sampled site's posterior draws as NumPy arrays.

        Returns ``{}`` when no inference results are stored.
        """
        if self.trace is None:
            return {}
        return {name: self._to_numpy(draws) for name, draws in self.trace.items()}