# Provenance: Hugging Face file page for ai_risk_engine.py
# (user petter2025, commit 76f997b verified, 4.61 kB).
# The original lines here were page chrome ("raw", "history blame", ...)
# captured by the scraper; kept as a comment so the module parses.
"""
Bayesian risk engine with hyperpriors (hierarchical Beta‑binomial).
Uses Pyro for variational inference.
"""
import logging
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
import torch
import numpy as np
from typing import Dict, Optional, List, Tuple
logger = logging.getLogger(__name__)
class AIRiskEngine:
    """
    Hierarchical Bayesian model for task-specific risk.

    Each task category has its own Beta-distributed success probability,
    and all categories share a common Gamma hyperprior over the Beta
    concentration parameters (hierarchical Beta-Bernoulli).  Posteriors
    are approximated with stochastic variational inference (SVI) via Pyro.
    """

    # Built-in labels for the first 10 categories; extra categories get
    # auto-generated names (see _category_labels).
    _DEFAULT_NAMES = [
        "chat", "code", "summary", "image", "audio",
        "iot", "switch", "server", "service", "unknown",
    ]

    def __init__(self, num_categories: int = 10):
        """
        Args:
            num_categories: number of task categories the model tracks.

        Bug fix vs. original: the name list was always 10 entries long
        regardless of ``num_categories``, so any other value desynchronized
        ``category_names`` from the parameter vectors and broke indexing.
        """
        self.num_categories = num_categories
        self.category_names: List[str] = self._category_labels(num_categories)
        self._history: List[Tuple[int, float]] = []  # (category_idx, success as 0.0/1.0)
        self._init_model()

    @staticmethod
    def _category_labels(num_categories: int) -> List[str]:
        """Return exactly ``num_categories`` labels, padding with "cat_<i>"."""
        names = AIRiskEngine._DEFAULT_NAMES[:num_categories]
        names.extend(f"cat_{i}" for i in range(len(names), num_categories))
        return names

    def _category_index(self, category: str) -> int:
        """Return the index of *category* in ``category_names``, or -1 if unknown."""
        try:
            return self.category_names.index(category)
        except ValueError:
            return -1

    def _init_model(self):
        """Seed initial values for the learnable parameters.

        NOTE(review): the original registered pyro.params named "alpha0" /
        "beta0" that collided with the sample sites of the same name in
        ``model`` and were never read by inference; they are kept here only
        as plain tensor attributes for backward compatibility.
        """
        self.alpha0 = torch.tensor(2.0)
        self.beta0 = torch.tensor(2.0)
        # Same param-store names as in ``guide``, so these initial values
        # seed the variational Beta parameters.
        self.p_alpha = pyro.param(
            "p_alpha", torch.ones(self.num_categories) * 2.0,
            constraint=dist.constraints.positive)
        self.p_beta = pyro.param(
            "p_beta", torch.ones(self.num_categories) * 2.0,
            constraint=dist.constraints.positive)

    def model(self, observations=None):
        """Generative model: Gamma hyperpriors -> per-category Beta -> Bernoulli obs.

        Args:
            observations: optional list of (category_idx, success) pairs;
                when present, conditions the Bernoulli likelihood on them.
        """
        # Global hyperprior (shared Beta concentration parameters).
        alpha0 = pyro.sample("alpha0", dist.Gamma(2.0, 1.0))
        beta0 = pyro.sample("beta0", dist.Gamma(2.0, 1.0))
        with pyro.plate("categories", self.num_categories):
            # Category-specific success probabilities.
            p = pyro.sample("p", dist.Beta(alpha0, beta0))
        # Guard empty lists so the data plate is never instantiated at size 0.
        if observations:
            cat_idx = torch.tensor([obs[0] for obs in observations],
                                   dtype=torch.long)
            successes = torch.tensor([obs[1] for obs in observations],
                                     dtype=torch.float)
            with pyro.plate("data", len(observations)):
                pyro.sample("obs", dist.Bernoulli(p[cat_idx]), obs=successes)

    def guide(self, observations=None):
        """Mean-field variational family matching ``model``'s latent sites."""
        alpha0_q = pyro.param("alpha0_q", torch.tensor(2.0),
                              constraint=dist.constraints.positive)
        beta0_q = pyro.param("beta0_q", torch.tensor(2.0),
                             constraint=dist.constraints.positive)
        pyro.sample("alpha0", dist.Gamma(alpha0_q, 1.0))
        pyro.sample("beta0", dist.Gamma(beta0_q, 1.0))
        with pyro.plate("categories", self.num_categories):
            p_alpha = pyro.param("p_alpha",
                                 torch.ones(self.num_categories) * 2.0,
                                 constraint=dist.constraints.positive)
            p_beta = pyro.param("p_beta",
                                torch.ones(self.num_categories) * 2.0,
                                constraint=dist.constraints.positive)
            pyro.sample("p", dist.Beta(p_alpha, p_beta))

    def update_outcome(self, category: str, success: bool):
        """Record one task outcome; refit once enough data has accumulated."""
        cat_idx = self._category_index(category)
        if cat_idx == -1:
            logger.warning("Unknown category: %s", category)
            return
        self._history.append((cat_idx, 1.0 if success else 0.0))
        # Cheap incremental refit once we have a handful of observations.
        if len(self._history) > 5:
            self._run_svi(steps=10)

    def _run_svi(self, steps: int = 50):
        """Run *steps* SVI updates against the full accumulated history.

        NOTE(review): a fresh Adam optimizer is created per call (as in the
        original), so optimizer state does not persist across refits.
        """
        if not self._history:
            return
        optimizer = pyro.optim.Adam({"lr": 0.01})
        svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
        for step in range(steps):
            loss = svi.step(self._history)
            if step % 10 == 0:
                logger.debug("SVI step %d, loss: %s", step, loss)

    def risk_score(self, category: str) -> Dict[str, float]:
        """Return posterior predictive risk metrics for *category*.

        Returns a dict with keys "mean", "p5", "p50", "p95"; falls back to
        uninformed defaults when the category is unknown or no data exists.
        """
        cat_idx = self._category_index(category)
        if cat_idx == -1 or not self._history:
            return {"mean": 0.5, "p5": 0.1, "p50": 0.5, "p95": 0.9}
        predictive = Predictive(self.model, guide=self.guide, num_samples=500)
        samples = predictive(self._history)
        # Predictive may prepend singleton plate dims to "p"; collapse to
        # (num_samples, num_categories) before selecting the category axis.
        p_all = samples["p"].detach().reshape(-1, self.num_categories)
        p_samples = p_all[:, cat_idx].numpy()
        return {
            "mean": float(p_samples.mean()),
            "p5": float(np.percentile(p_samples, 5)),
            "p50": float(np.percentile(p_samples, 50)),
            "p95": float(np.percentile(p_samples, 95)),
        }