petter2025 commited on
Commit
76f997b
·
verified ·
1 Parent(s): a3ed1dc

Update ai_risk_engine.py

Browse files
Files changed (1) hide show
  1. ai_risk_engine.py +13 -9
ai_risk_engine.py CHANGED
@@ -1,12 +1,14 @@
1
  """
2
  Bayesian risk engine with hyperpriors (hierarchical Beta‑binomial).
 
3
  """
4
  import logging
5
  import pyro
6
  import pyro.distributions as dist
7
  from pyro.infer import SVI, Trace_ELBO, Predictive
8
  import torch
9
- from typing import Dict, Optional
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
@@ -19,14 +21,16 @@ class AIRiskEngine:
19
  def __init__(self, num_categories: int = 10):
20
  self.num_categories = num_categories
21
  self.category_names = ["chat", "code", "summary", "image", "audio", "iot", "switch", "server", "service", "unknown"]
 
22
  self._init_model()
23
- self._history = [] # store (category, success) for later updates
24
 
25
  def _init_model(self):
26
- # Hyperpriors
27
  self.alpha0 = pyro.param("alpha0", torch.tensor(2.0), constraint=dist.constraints.positive)
28
  self.beta0 = pyro.param("beta0", torch.tensor(2.0), constraint=dist.constraints.positive)
29
- # We'll learn these via SVI when update is called
 
 
30
 
31
  def model(self, observations=None):
32
  # Global hyperprior (concentration parameters)
@@ -38,14 +42,13 @@ class AIRiskEngine:
38
  p = pyro.sample("p", dist.Beta(alpha0, beta0))
39
 
40
  if observations is not None:
41
- # Observations: list of (category_idx, success)
42
  cat_idx = torch.tensor([obs[0] for obs in observations])
43
  successes = torch.tensor([obs[1] for obs in observations])
44
  with pyro.plate("data", len(observations)):
45
  pyro.sample("obs", dist.Bernoulli(p[cat_idx]), obs=successes)
46
 
47
  def guide(self, observations=None):
48
- # Variational parameters
49
  alpha0_q = pyro.param("alpha0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
50
  beta0_q = pyro.param("beta0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
51
  pyro.sample("alpha0", dist.Gamma(alpha0_q, 1.0))
@@ -60,6 +63,7 @@ class AIRiskEngine:
60
  """Store observation and optionally trigger a learning step."""
61
  cat_idx = self.category_names.index(category) if category in self.category_names else -1
62
  if cat_idx == -1:
 
63
  return
64
  self._history.append((cat_idx, 1.0 if success else 0.0))
65
  # Run a few steps of SVI
@@ -71,10 +75,10 @@ class AIRiskEngine:
71
  return
72
  optimizer = pyro.optim.Adam({"lr": 0.01})
73
  svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
74
- for _ in range(steps):
75
  loss = svi.step(self._history)
76
- if steps % 10 == 0:
77
- logger.debug(f"SVI loss: {loss}")
78
 
79
  def risk_score(self, category: str) -> Dict[str, float]:
80
  """Return posterior predictive risk metrics for a category."""
 
1
  """
2
  Bayesian risk engine with hyperpriors (hierarchical Beta‑binomial).
3
+ Uses Pyro for variational inference.
4
  """
5
  import logging
6
  import pyro
7
  import pyro.distributions as dist
8
  from pyro.infer import SVI, Trace_ELBO, Predictive
9
  import torch
10
+ import numpy as np
11
+ from typing import Dict, Optional, List, Tuple
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
21
  def __init__(self, num_categories: int = 10):
22
  self.num_categories = num_categories
23
  self.category_names = ["chat", "code", "summary", "image", "audio", "iot", "switch", "server", "service", "unknown"]
24
+ self._history: List[Tuple[int, float]] = [] # (category_idx, success)
25
  self._init_model()
 
26
 
27
  def _init_model(self):
28
+ # Hyperpriors (Gamma for alpha, beta)
29
  self.alpha0 = pyro.param("alpha0", torch.tensor(2.0), constraint=dist.constraints.positive)
30
  self.beta0 = pyro.param("beta0", torch.tensor(2.0), constraint=dist.constraints.positive)
31
+ # Category‑specific parameters
32
+ self.p_alpha = pyro.param("p_alpha", torch.ones(self.num_categories) * 2.0, constraint=dist.constraints.positive)
33
+ self.p_beta = pyro.param("p_beta", torch.ones(self.num_categories) * 2.0, constraint=dist.constraints.positive)
34
 
35
  def model(self, observations=None):
36
  # Global hyperprior (concentration parameters)
 
42
  p = pyro.sample("p", dist.Beta(alpha0, beta0))
43
 
44
  if observations is not None:
 
45
  cat_idx = torch.tensor([obs[0] for obs in observations])
46
  successes = torch.tensor([obs[1] for obs in observations])
47
  with pyro.plate("data", len(observations)):
48
  pyro.sample("obs", dist.Bernoulli(p[cat_idx]), obs=successes)
49
 
50
  def guide(self, observations=None):
51
+ # Variational parameters for hyperpriors
52
  alpha0_q = pyro.param("alpha0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
53
  beta0_q = pyro.param("beta0_q", torch.tensor(2.0), constraint=dist.constraints.positive)
54
  pyro.sample("alpha0", dist.Gamma(alpha0_q, 1.0))
 
63
  """Store observation and optionally trigger a learning step."""
64
  cat_idx = self.category_names.index(category) if category in self.category_names else -1
65
  if cat_idx == -1:
66
+ logger.warning(f"Unknown category: {category}")
67
  return
68
  self._history.append((cat_idx, 1.0 if success else 0.0))
69
  # Run a few steps of SVI
 
75
  return
76
  optimizer = pyro.optim.Adam({"lr": 0.01})
77
  svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
78
+ for step in range(steps):
79
  loss = svi.step(self._history)
80
+ if step % 10 == 0:
81
+ logger.debug(f"SVI step {step}, loss: {loss}")
82
 
83
  def risk_score(self, category: str) -> Dict[str, float]:
84
  """Return posterior predictive risk metrics for a category."""