File size: 4,611 Bytes
3d17053
7e061fe
76f997b
3d17053
7e061fe
 
 
 
 
76f997b
 
3d17053
7e061fe
3d17053
7e061fe
 
 
 
 
3d17053
7e061fe
 
 
76f997b
7e061fe
3d17053
7e061fe
76f997b
7e061fe
 
76f997b
 
 
3d17053
7e061fe
 
 
 
3d17053
7e061fe
 
 
 
 
 
 
 
 
 
 
76f997b
7e061fe
 
 
 
 
 
 
 
 
3d17053
 
7e061fe
 
 
76f997b
7e061fe
 
 
 
 
 
 
 
 
 
 
76f997b
7e061fe
76f997b
 
7e061fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Bayesian risk engine with hyperpriors (hierarchical Beta‑binomial).
Uses Pyro for variational inference.
"""
import logging
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
import torch
import numpy as np
from typing import Dict, Optional, List, Tuple

logger = logging.getLogger(__name__)

class AIRiskEngine:
    """
    Hierarchical Bayesian (Beta-binomial) model for task-specific risk.

    Each task category i has its own success probability p[i] drawn from a
    shared Beta(alpha0, beta0) prior; the concentration parameters alpha0 and
    beta0 themselves carry Gamma hyperpriors, so sparse categories borrow
    strength from the rest. Posteriors are fit incrementally with Pyro SVI.
    """

    def __init__(self, num_categories: int = 10):
        """
        Args:
            num_categories: number of task categories tracked. Should match
                the length of ``category_names`` so every name maps to an
                index (the default list has exactly 10 entries).
        """
        self.num_categories = num_categories
        self.category_names = ["chat", "code", "summary", "image", "audio", "iot", "switch", "server", "service", "unknown"]
        # Observation log of (category_idx, success) with success in {0.0, 1.0}.
        self._history: List[Tuple[int, float]] = []
        # Persistent SVI handle, created lazily in _run_svi so the Adam
        # optimizer keeps its per-parameter moment estimates across
        # incremental updates instead of being reset on every call.
        self._svi: Optional[SVI] = None
        self._init_model()

    def _category_index(self, category: str) -> int:
        """Return the index of *category* in ``category_names``, or -1 if unknown."""
        try:
            return self.category_names.index(category)
        except ValueError:
            return -1

    def _init_model(self):
        """Register variational parameters in Pyro's global param store.

        The hyperprior params are registered under the guide's site names
        ("alpha0_q"/"beta0_q") so these initial values are the ones SVI
        actually optimizes. Registering them as "alpha0"/"beta0" (as before)
        collided with the *sample* site names in ``model`` and produced
        parameters the guide never referenced.
        """
        self.alpha0 = pyro.param("alpha0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
        self.beta0 = pyro.param("beta0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
        # Per-category Beta parameters; same names the guide re-uses.
        self.p_alpha = pyro.param("p_alpha", torch.ones(self.num_categories) * 2.0, constraint=dist.constraints.positive)
        self.p_beta = pyro.param("p_beta", torch.ones(self.num_categories) * 2.0, constraint=dist.constraints.positive)

    def model(self, observations=None):
        """Generative model.

        alpha0, beta0 ~ Gamma(2, 1)            (global hyperpriors)
        p[i]          ~ Beta(alpha0, beta0)    (per-category success prob.)
        obs[j]        ~ Bernoulli(p[cat[j]])   (observed outcomes)

        Args:
            observations: optional sequence of (category_idx, success) pairs;
                when falsy the model is purely generative. A truthiness check
                (not ``is not None``) avoids building a size-0 data plate for
                an empty list.
        """
        alpha0 = pyro.sample("alpha0", dist.Gamma(2.0, 1.0))
        beta0 = pyro.sample("beta0", dist.Gamma(2.0, 1.0))

        with pyro.plate("categories", self.num_categories):
            p = pyro.sample("p", dist.Beta(alpha0, beta0))

        if observations:
            cat_idx = torch.tensor([obs[0] for obs in observations], dtype=torch.long)
            successes = torch.tensor([obs[1] for obs in observations])
            with pyro.plate("data", len(observations)):
                pyro.sample("obs", dist.Bernoulli(p[cat_idx]), obs=successes)

    def guide(self, observations=None):
        """Mean-field variational guide mirroring ``model``'s latent sites.

        Gamma rate is fixed at 1.0; only the concentrations alpha0_q/beta0_q
        and the per-category Beta parameters are learned.
        """
        alpha0_q = pyro.param("alpha0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
        beta0_q = pyro.param("beta0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
        pyro.sample("alpha0", dist.Gamma(alpha0_q, 1.0))
        pyro.sample("beta0", dist.Gamma(beta0_q, 1.0))

        with pyro.plate("categories", self.num_categories):
            p_alpha = pyro.param("p_alpha", torch.ones(self.num_categories) * 2.0, constraint=dist.constraints.positive)
            p_beta = pyro.param("p_beta", torch.ones(self.num_categories) * 2.0, constraint=dist.constraints.positive)
            pyro.sample("p", dist.Beta(p_alpha, p_beta))

    def update_outcome(self, category: str, success: bool):
        """Record one task outcome and, once enough data exists, refine the posterior.

        Unknown categories are logged and ignored rather than raising, so a
        misclassified task never crashes the caller.
        """
        cat_idx = self._category_index(category)
        if cat_idx == -1:
            logger.warning("Unknown category: %s", category)
            return
        self._history.append((cat_idx, 1.0 if success else 0.0))
        # A handful of incremental SVI steps per new observation; the
        # optimizer state persists across calls (see _run_svi).
        if len(self._history) > 5:
            self._run_svi(steps=10)

    def _run_svi(self, steps=50):
        """Run *steps* SVI iterations over the full observation history."""
        if not self._history:
            return
        if self._svi is None:
            # Build once: recreating Adam on every call (as before) discarded
            # its moment estimates and effectively restarted optimization.
            optimizer = pyro.optim.Adam({"lr": 0.01})
            self._svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
        for step in range(steps):
            loss = self._svi.step(self._history)
            if step % 10 == 0:
                logger.debug("SVI step %d, loss: %s", step, loss)

    def risk_score(self, category: str) -> Dict[str, float]:
        """Return posterior predictive risk metrics for *category*.

        Returns:
            dict with keys "mean", "p5", "p50", "p95" over the posterior
            samples of the category's success probability. For an unknown
            category or before any data, a wide uninformed fallback is
            returned instead of touching the model.
        """
        cat_idx = self._category_index(category)
        if cat_idx == -1 or not self._history:
            return {"mean": 0.5, "p5": 0.1, "p50": 0.5, "p95": 0.9}
        predictive = Predictive(self.model, guide=self.guide, num_samples=500)
        samples = predictive(self._history)
        # Predictive may prepend extra plate/sample dims; flatten so the last
        # axis is categories before selecting the one we want. .cpu() keeps
        # the numpy conversion safe if params ever live on an accelerator.
        p = samples["p"].reshape(-1, self.num_categories)
        p_samples = p[:, cat_idx].detach().cpu().numpy()
        return {
            "mean": float(p_samples.mean()),
            "p5": float(np.percentile(p_samples, 5)),
            "p50": float(np.percentile(p_samples, 50)),
            "p95": float(np.percentile(p_samples, 95))
        }