# finalRLEnv/env/action_space.py
"""
Hierarchical Factored Action Space.
4 heads decoded sequentially at each step:
Head 1: Meta-action — what high-level thing to do?
Head 2: Specialist selection — which specialist(s) to call?
Head 3: Delegation mode — how to call them?
Head 4: Mode parameters — how many rounds, threshold, etc.?
Design: Sequential decomposition keeps each head's distribution
tractable for PPO. The policy outputs a single flat action vector, but
training uses the factored structure.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional
import numpy as np


class MetaAction(IntEnum):
"""Top-level orchestrator decisions."""
CALL_SPECIALIST = 0 # Call one or more specialists
STOP = 1 # Stop delegation, synthesize output
CALL_MEDIATOR = 2 # Call conflict mediator
CLARIFY_TASK = 3 # Request task clarification (if ambiguous)
DELEGATE_SUBTASK = 4 # Delegate a sub-problem (2nd level)
RETRY_FAILED = 5 # Retry a failed specialist with fallback
PARALLEL_SPAWN = 6 # Spawn parallel specialists
SPAWN_SPECIALIST = 7 # Policy requests a new specialist be created


class DelegationMode(IntEnum):
"""How to execute the selected specialists."""
SEQUENTIAL = 0 # A → B → C (each sees previous output)
PARALLEL = 1 # A, B, C all run simultaneously
FAN_OUT_REDUCE = 2 # A, B, C run → mediator reduces output
ITERATIVE = 3 # Run specialist, check output, loop until threshold
CONDITIONAL = 4 # Run A; if condition met, run B, else C
PRIORITY_QUEUE = 5 # Run in priority order, stop when threshold met
BROADCAST = 6 # Send to all specialists, take first to complete


@dataclass
class FactoredAction:
"""
The complete action decoded from all 4 heads.
This is what gets passed to the environment's step() function.
"""
meta_action: MetaAction
specialist_ids: list[str] # Which specialists to call
delegation_mode: DelegationMode # How to call them
mode_params: dict # Mode-specific parameters
raw_action: Optional[np.ndarray] = None # Raw policy output (for logging)
def is_terminal(self) -> bool:
"""Returns True if this action ends the episode."""
return self.meta_action == MetaAction.STOP
def to_log_dict(self) -> dict:
return {
"meta_action": self.meta_action.name,
"specialists": self.specialist_ids,
"mode": self.delegation_mode.name,
"params": self.mode_params,
}


class ActionDecoder:
"""
Decodes a flat action vector from the policy into a FactoredAction.
    Action vector layout:
        [0]                       : meta_action index (int, 0–7)
        [1 : 1+max_specialists]   : specialist selection (multi-hot float)
        [1+max_specialists]       : delegation_mode index (int, 0–6)
        [2+max_specialists : end] : mode_params (continuous, 4 floats)
    Total action dim = 1 + max_specialists + 1 + 4 = max_specialists + 6
"""
NUM_META_ACTIONS = len(MetaAction)
NUM_DELEGATION_MODES = len(DelegationMode)
NUM_MODE_PARAMS = 4
def __init__(self, specialist_ids: list[str], max_specialists: int = 8):
self.specialist_ids = specialist_ids
self.max_specialists = min(len(specialist_ids), max_specialists)
self.action_dim = self.max_specialists + 6
def decode(
self,
action_vector: np.ndarray,
valid_specialist_mask: Optional[np.ndarray] = None,
) -> FactoredAction:
"""
Decode a flat action vector into a FactoredAction.
Args:
action_vector: Flat numpy array from the policy
valid_specialist_mask: Binary mask, 1 = valid, 0 = masked out
(enforces DAG constraints)
"""
action_vector = np.asarray(action_vector, dtype=np.float32)
# Head 1: Meta-action
meta_idx = int(np.clip(round(action_vector[0]), 0, self.NUM_META_ACTIONS - 1))
meta_action = MetaAction(meta_idx)
# Head 2: Specialist selection (multi-hot)
spec_logits = action_vector[1: 1 + self.max_specialists]
        if valid_specialist_mask is not None:
            # Use -inf (rather than multiplying by 0) so the argmax fallback
            # below can never pick a masked-out specialist.
            mask = valid_specialist_mask[: self.max_specialists]
            spec_logits = np.where(mask > 0, spec_logits, -np.inf)
        selected_indices = np.where(spec_logits > 0.0)[0]
        if len(selected_indices) == 0 and meta_action == MetaAction.CALL_SPECIALIST:
            # Fallback: select the highest-scoring specialist
            selected_indices = [int(np.argmax(spec_logits))]
selected_ids = [
self.specialist_ids[i]
for i in selected_indices
if i < len(self.specialist_ids)
]
# Head 3: Delegation mode
mode_idx = int(np.clip(
round(action_vector[1 + self.max_specialists]),
0, self.NUM_DELEGATION_MODES - 1
))
delegation_mode = DelegationMode(mode_idx)
# Head 4: Mode parameters
param_start = 2 + self.max_specialists
raw_params = action_vector[param_start: param_start + self.NUM_MODE_PARAMS]
mode_params = self._decode_mode_params(delegation_mode, raw_params)
return FactoredAction(
meta_action=meta_action,
specialist_ids=selected_ids,
delegation_mode=delegation_mode,
mode_params=mode_params,
raw_action=action_vector,
)
def _decode_mode_params(
self, mode: DelegationMode, raw_params: np.ndarray
) -> dict:
"""Decode mode-specific parameters from the raw continuous params."""
p = np.clip(raw_params, 0.0, 1.0)
if mode == DelegationMode.ITERATIVE:
return {
"max_rounds": int(1 + round(p[0] * 4)), # 1–5 rounds
"quality_threshold": float(0.5 + p[1] * 0.5), # 0.5–1.0
}
elif mode == DelegationMode.PRIORITY_QUEUE:
return {
"stop_threshold": float(0.6 + p[0] * 0.4), # 0.6–1.0
}
elif mode == DelegationMode.CONDITIONAL:
return {
"condition_threshold": float(0.4 + p[0] * 0.6), # 0.4–1.0
}
else:
            return {"parallel_budget_ms": int(2000 + p[0] * 6000)}  # 2000–8000 ms
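
    # Worked example (illustrative): for ITERATIVE mode, raw params starting
    # with [0.5, 0.8, ...] decode to max_rounds = 1 + round(0.5 * 4) = 3 and
    # quality_threshold = 0.5 + 0.8 * 0.5 ≈ 0.9.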
def get_action_dim(self) -> int:
return self.action_dim
def build_specialist_mask(
self, valid_specialist_ids: list[str]
) -> np.ndarray:
"""Build a binary mask for valid specialist selections."""
mask = np.zeros(self.max_specialists, dtype=np.float32)
valid_set = set(valid_specialist_ids)
for i, sid in enumerate(self.specialist_ids[: self.max_specialists]):
if sid in valid_set:
mask[i] = 1.0
return mask
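

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the specialist ids, the mask, and
# the sample action vector below are made-up values, not part of the original
# module).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = ActionDecoder(specialist_ids=["research", "code", "review"])
    # action_dim = max_specialists + 6 = 3 + 6 = 9
    # layout:   [meta, s0,  s1,   s2,  mode, p0,  p1,  p2,  p3]
    sample = np.array([0.0, 0.9, -0.2, 0.7, 3.0, 0.5, 0.8, 0.0, 0.0])

    # Suppose the DAG constraints only allow "research" and "code" this step.
    mask = decoder.build_specialist_mask(["research", "code"])

    action = decoder.decode(sample, valid_specialist_mask=mask)
    # Expected: CALL_SPECIALIST in ITERATIVE mode on ["research"] with
    # max_rounds = 3 and quality_threshold ≈ 0.9.
    print(action.to_log_dict())
    print("terminal:", action.is_terminal())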