petter2025 commited on
Commit
f195c5e
·
verified ·
1 Parent(s): bb7b3c2

Delete advanced_inference.py

Browse files
Files changed (1) hide show
  1. advanced_inference.py +0 -73
advanced_inference.py DELETED
@@ -1,73 +0,0 @@
1
- """
2
- Hamiltonian Monte Carlo (NUTS) for complex pattern discovery.
3
- Uses Pyro with NUTS and provides posterior summaries.
4
- """
5
- import logging
6
- import pyro
7
- import pyro.distributions as dist
8
- from pyro.infer import MCMC, NUTS
9
- import torch
10
- import numpy as np
11
- import pandas as pd
12
- from typing import Dict, Any, Optional
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
- class HMCAnalyzer:
17
- """Runs HMC on a simple regression model to demonstrate advanced inference."""
18
-
19
- def __init__(self):
20
- self.mcmc = None
21
- self.trace = None
22
-
23
- def _model(self, x, y=None):
24
- # Linear regression with unknown noise
25
- alpha = pyro.sample("alpha", dist.Normal(0, 10))
26
- beta = pyro.sample("beta", dist.Normal(0, 1))
27
- sigma = pyro.sample("sigma", dist.HalfNormal(1))
28
- mu = alpha + beta * x
29
- with pyro.plate("data", len(x)):
30
- pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
31
-
32
- def run_inference(self, data: Optional[pd.DataFrame] = None, num_samples: int = 500, warmup: int = 200):
33
- """
34
- Run HMC on synthetic or provided data.
35
- If no data, generate synthetic trend data.
36
- """
37
- if data is None:
38
- # Create synthetic data: a linear trend with noise
39
- x = torch.linspace(0, 10, 50)
40
- true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
41
- y = true_alpha + true_beta * x + torch.randn(50) * true_sigma
42
- else:
43
- # Assume data has columns 'x' and 'y'
44
- x = torch.tensor(data['x'].values, dtype=torch.float32)
45
- y = torch.tensor(data['y'].values, dtype=torch.float32)
46
-
47
- nuts_kernel = NUTS(self._model)
48
- self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup)
49
- self.mcmc.run(x, y)
50
-
51
- self.trace = self.mcmc.get_samples()
52
- return self._summary()
53
-
54
- def _summary(self) -> Dict[str, Any]:
55
- """Return summary statistics of posterior samples."""
56
- if self.trace is None:
57
- return {}
58
- summary = {}
59
- for key in ['alpha', 'beta', 'sigma']:
60
- samples = self.trace[key].numpy()
61
- summary[key] = {
62
- 'mean': float(samples.mean()),
63
- 'std': float(samples.std()),
64
- 'hpd_5': float(np.percentile(samples, 5)),
65
- 'hpd_95': float(np.percentile(samples, 95))
66
- }
67
- return summary
68
-
69
- def get_trace_data(self) -> Dict[str, np.ndarray]:
70
- """Return posterior samples for plotting."""
71
- if self.trace is None:
72
- return {}
73
- return {k: v.numpy() for k, v in self.trace.items()}