Spaces:
Sleeping
Sleeping
Commit ·
abacc33
1
Parent(s): d5d74e8
minor fix
Browse files- rl_agent.py +5 -2
- src/adaptive_alert_triage/server.py +5 -1
rl_agent.py
CHANGED
|
@@ -650,13 +650,16 @@ class PPOTrainer:
|
|
| 650 |
# ------------------------------------------------------------------
|
| 651 |
|
| 652 |
def act(self, obs) -> Any:
|
| 653 |
-
"""
|
| 654 |
from adaptive_alert_triage.models import Action
|
| 655 |
if not obs.alerts:
|
| 656 |
raise ValueError("No alerts")
|
| 657 |
s = encode_state(obs)
|
| 658 |
probs, _ = self.net.forward(s)
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
| 660 |
alert = _select_alert(obs, a)
|
| 661 |
return Action(alert_id=alert.id, action_type=_ACTION_NAMES[a])
|
| 662 |
|
|
|
|
| 650 |
# ------------------------------------------------------------------
|
| 651 |
|
| 652 |
def act(self, obs) -> Any:
|
| 653 |
+
"""Stochastic action matching training behavior."""
|
| 654 |
from adaptive_alert_triage.models import Action
|
| 655 |
if not obs.alerts:
|
| 656 |
raise ValueError("No alerts")
|
| 657 |
s = encode_state(obs)
|
| 658 |
probs, _ = self.net.forward(s)
|
| 659 |
+
# Sample from policy distribution (same as training), NOT argmax!
|
| 660 |
+
# argmax collapses a learned distribution like [0.35, 0.25, 0.22, 0.18]
|
| 661 |
+
# into always picking the same action.
|
| 662 |
+
a = int(np.random.choice(4, p=probs))
|
| 663 |
alert = _select_alert(obs, a)
|
| 664 |
return Action(alert_id=alert.id, action_type=_ACTION_NAMES[a])
|
| 665 |
|
src/adaptive_alert_triage/server.py
CHANGED
|
@@ -485,7 +485,11 @@ async def recommend():
|
|
| 485 |
ppo.net.h, ppo.net.c = old_h, old_c
|
| 486 |
# -----------------------------------------------------------------
|
| 487 |
|
| 488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
act = _ACTION_NAMES[idx]
|
| 490 |
conf = round(float(probs[idx]) * 100, 1)
|
| 491 |
return {
|
|
|
|
| 485 |
ppo.net.h, ppo.net.c = old_h, old_c
|
| 486 |
# -----------------------------------------------------------------
|
| 487 |
|
| 488 |
+
# CRITICAL: Use sampling (same as training), NOT argmax!
|
| 489 |
+
# argmax always picks the single highest prob, collapsing a
|
| 490 |
+
# balanced policy like [0.35, 0.25, 0.22, 0.18] into "always
|
| 491 |
+
# INVESTIGATE". Sampling reproduces the trained behavior.
|
| 492 |
+
idx = int(np.random.choice(4, p=probs))
|
| 493 |
act = _ACTION_NAMES[idx]
|
| 494 |
conf = round(float(probs[idx]) * 100, 1)
|
| 495 |
return {
|