Hacktrix-121 committed on
Commit
abacc33
·
1 Parent(s): d5d74e8

minor fix

Browse files
rl_agent.py CHANGED
@@ -650,13 +650,16 @@ class PPOTrainer:
650
  # ------------------------------------------------------------------
651
 
652
  def act(self, obs) -> Any:
653
- """Greedy action for deployment / evaluation."""
654
  from adaptive_alert_triage.models import Action
655
  if not obs.alerts:
656
  raise ValueError("No alerts")
657
  s = encode_state(obs)
658
  probs, _ = self.net.forward(s)
659
- a = int(np.argmax(probs))
 
 
 
660
  alert = _select_alert(obs, a)
661
  return Action(alert_id=alert.id, action_type=_ACTION_NAMES[a])
662
 
 
650
  # ------------------------------------------------------------------
651
 
652
  def act(self, obs) -> Any:
653
+ """Stochastic action matching training behavior."""
654
  from adaptive_alert_triage.models import Action
655
  if not obs.alerts:
656
  raise ValueError("No alerts")
657
  s = encode_state(obs)
658
  probs, _ = self.net.forward(s)
659
+ # Sample from policy distribution (same as training), NOT argmax!
660
+ # argmax collapses a learned distribution like [0.35, 0.25, 0.22, 0.18]
661
+ # into always picking the same action.
662
+ a = int(np.random.choice(4, p=probs))
663
  alert = _select_alert(obs, a)
664
  return Action(alert_id=alert.id, action_type=_ACTION_NAMES[a])
665
 
src/adaptive_alert_triage/server.py CHANGED
@@ -485,7 +485,11 @@ async def recommend():
485
  ppo.net.h, ppo.net.c = old_h, old_c
486
  # -----------------------------------------------------------------
487
 
488
- idx = int(np.argmax(probs))
 
 
 
 
489
  act = _ACTION_NAMES[idx]
490
  conf = round(float(probs[idx]) * 100, 1)
491
  return {
 
485
  ppo.net.h, ppo.net.c = old_h, old_c
486
  # -----------------------------------------------------------------
487
 
488
+ # CRITICAL: Use sampling (same as training), NOT argmax!
489
+ # argmax always picks the single highest prob, collapsing a
490
+ # balanced policy like [0.35, 0.25, 0.22, 0.18] into "always
491
+ # INVESTIGATE". Sampling reproduces the trained behavior.
492
+ idx = int(np.random.choice(4, p=probs))
493
  act = _ACTION_NAMES[idx]
494
  conf = round(float(probs[idx]) * 100, 1)
495
  return {