Spaces:
Runtime error
Runtime error
File size: 6,960 Bytes
02ff91f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | """
Hierarchical Factored Action Space.
4 heads decoded sequentially at each step:
Head 1: Meta-action — what high-level thing to do?
Head 2: Specialist selection — which specialist(s) to call?
Head 3: Delegation mode — how to call them?
Head 4: Mode parameters — how many rounds, threshold, etc.?
Design: Sequential decomposition keeps each head's distribution
tractable for PPO. The policy sees a flattened joint action, but
training uses the factored structure.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional
import numpy as np
class MetaAction(IntEnum):
    """High-level decision the orchestrator takes at a step (head 1)."""

    # Call one or more specialists.
    CALL_SPECIALIST = 0
    # Stop delegating and synthesize the output.
    STOP = 1
    # Call the conflict mediator.
    CALL_MEDIATOR = 2
    # Request clarification of the task (when it is ambiguous).
    CLARIFY_TASK = 3
    # Delegate a sub-problem (second level).
    DELEGATE_SUBTASK = 4
    # Retry a failed specialist with a fallback.
    RETRY_FAILED = 5
    # Spawn specialists in parallel.
    PARALLEL_SPAWN = 6
    # The policy asks for a new specialist to be created.
    SPAWN_SPECIALIST = 7
class DelegationMode(IntEnum):
    """Execution strategy for the selected specialists (head 3)."""

    # A -> B -> C: each specialist sees the previous output.
    SEQUENTIAL = 0
    # A, B and C all run simultaneously.
    PARALLEL = 1
    # A, B and C run, then the mediator reduces their output.
    FAN_OUT_REDUCE = 2
    # Run a specialist, check its output, loop until a threshold is reached.
    ITERATIVE = 3
    # Run A; if the condition is met run B, otherwise C.
    CONDITIONAL = 4
    # Run in priority order and stop once the threshold is met.
    PRIORITY_QUEUE = 5
    # Send to all specialists and take the first to complete.
    BROADCAST = 6
@dataclass
class FactoredAction:
    """
    Fully decoded action assembled from the four policy heads.

    Per the original documentation, this is the object handed to the
    environment's step() function.
    """

    # Head 1: the high-level decision.
    meta_action: MetaAction
    # Head 2: identifiers of the specialists to call.
    specialist_ids: list[str]
    # Head 3: how the selected specialists are executed.
    delegation_mode: DelegationMode
    # Head 4: parameters specific to the delegation mode.
    mode_params: dict
    # Raw policy output, kept for logging.
    raw_action: Optional[np.ndarray] = None

    def is_terminal(self) -> bool:
        """Return True when the meta-action is STOP, i.e. the episode ends."""
        return MetaAction.STOP == self.meta_action

    def to_log_dict(self) -> dict:
        """Return a dict summary using enum names instead of enum members."""
        meta_name = self.meta_action.name
        mode_name = self.delegation_mode.name
        return {
            "meta_action": meta_name,
            "specialists": self.specialist_ids,
            "mode": mode_name,
            "params": self.mode_params,
        }
class ActionDecoder:
    """
    Decode a flat action vector from the policy into a FactoredAction.

    Action vector layout (n = max_specialists):
        [0]          : meta_action index, rounded and clipped to
                       0 .. NUM_META_ACTIONS - 1 (0-7)
        [1 : 1+n]    : specialist selection scores; an entry is selected
                       when its (masked) score is > 0
        [1+n]        : delegation_mode index, rounded and clipped to
                       0 .. NUM_DELEGATION_MODES - 1 (0-6)
        [2+n : 6+n]  : mode_params (NUM_MODE_PARAMS continuous floats,
                       clipped to [0, 1] before use)

    Total action dim = 1 + n + 1 + 4 = n + 6
    """

    NUM_META_ACTIONS = len(MetaAction)
    NUM_DELEGATION_MODES = len(DelegationMode)
    NUM_MODE_PARAMS = 4

    def __init__(self, specialist_ids: list[str], max_specialists: int = 8):
        """
        Args:
            specialist_ids: Identifiers of the specialists; position i in
                the selection slice of the action vector maps to
                specialist_ids[i].
            max_specialists: Upper bound on the number of selection slots.
                The effective count is capped at len(specialist_ids).
        """
        self.specialist_ids = specialist_ids
        self.max_specialists = min(len(specialist_ids), max_specialists)
        # One meta-action slot + one delegation-mode slot + the mode params.
        self.action_dim = self.max_specialists + 2 + self.NUM_MODE_PARAMS

    def decode(
        self,
        action_vector: np.ndarray,
        valid_specialist_mask: Optional[np.ndarray] = None,
    ) -> FactoredAction:
        """
        Decode a flat action vector into a FactoredAction.

        Args:
            action_vector: Flat array from the policy, laid out as described
                in the class docstring.
            valid_specialist_mask: Binary mask, 1 = valid, 0 = masked out.
                Only the first max_specialists entries are used.

        Returns:
            The decoded FactoredAction. Its raw_action is the float32 copy
            of action_vector.
        """
        action_vector = np.asarray(action_vector, dtype=np.float32)
        num_slots = self.max_specialists

        # Head 1: Meta-action
        meta_idx = int(
            np.clip(round(action_vector[0]), 0, self.NUM_META_ACTIONS - 1)
        )
        meta_action = MetaAction(meta_idx)

        # Head 2: Specialist selection (multi-hot)
        spec_logits = action_vector[1: 1 + num_slots]
        mask = None
        if valid_specialist_mask is not None:
            mask = np.asarray(valid_specialist_mask)[:num_slots]
            spec_logits = spec_logits * mask
        selected_indices = np.where(spec_logits > 0.0)[0]

        if len(selected_indices) == 0 and meta_action == MetaAction.CALL_SPECIALIST:
            # Fallback: pick the highest-scoring specialist. Masked entries
            # were zeroed by the multiplication above, so a plain argmax
            # would prefer a masked-out slot (0) over a valid slot with a
            # negative score. Exclude masked slots explicitly instead.
            if mask is None:
                candidates = spec_logits
                has_candidate = spec_logits.size > 0
            else:
                is_valid = mask > 0
                candidates = np.where(is_valid, spec_logits, -np.inf)
                has_candidate = bool(np.any(is_valid))
            # With no slots at all, or every slot masked out, select nothing
            # rather than raising or violating the mask.
            selected_indices = (
                [int(np.argmax(candidates))] if has_candidate else []
            )

        selected_ids = [
            self.specialist_ids[int(i)]
            for i in selected_indices
            if i < len(self.specialist_ids)
        ]

        # Head 3: Delegation mode
        mode_idx = int(np.clip(
            round(action_vector[1 + num_slots]),
            0, self.NUM_DELEGATION_MODES - 1
        ))
        delegation_mode = DelegationMode(mode_idx)

        # Head 4: Mode parameters
        param_start = 2 + num_slots
        raw_params = action_vector[param_start: param_start + self.NUM_MODE_PARAMS]
        mode_params = self._decode_mode_params(delegation_mode, raw_params)

        return FactoredAction(
            meta_action=meta_action,
            specialist_ids=selected_ids,
            delegation_mode=delegation_mode,
            mode_params=mode_params,
            raw_action=action_vector,
        )

    def _decode_mode_params(
        self, mode: DelegationMode, raw_params: np.ndarray
    ) -> dict:
        """
        Decode mode-specific parameters from the raw continuous params.

        The raw values are clipped to [0, 1] and then scaled into the range
        each parameter uses. Modes without a dedicated branch get a
        "parallel_budget_ms" entry.
        """
        p = np.clip(raw_params, 0.0, 1.0)
        if mode == DelegationMode.ITERATIVE:
            return {
                "max_rounds": int(1 + round(p[0] * 4)),        # 1-5 rounds
                "quality_threshold": float(0.5 + p[1] * 0.5),  # 0.5-1.0
            }
        if mode == DelegationMode.PRIORITY_QUEUE:
            return {
                "stop_threshold": float(0.6 + p[0] * 0.4),  # 0.6-1.0
            }
        if mode == DelegationMode.CONDITIONAL:
            return {
                "condition_threshold": float(0.4 + p[0] * 0.6),  # 0.4-1.0
            }
        return {"parallel_budget_ms": int(2000 + p[0] * 6000)}  # 2000-8000

    def get_action_dim(self) -> int:
        """Return the length of the flat action vector this decoder expects."""
        return self.action_dim

    def build_specialist_mask(
        self, valid_specialist_ids: list[str]
    ) -> np.ndarray:
        """
        Build a binary mask for valid specialist selections.

        Returns a float32 array of length max_specialists whose entry i is
        1.0 when specialist_ids[i] is in valid_specialist_ids, else 0.0.
        """
        mask = np.zeros(self.max_specialists, dtype=np.float32)
        valid_set = set(valid_specialist_ids)
        for i, sid in enumerate(self.specialist_ids[: self.max_specialists]):
            if sid in valid_set:
                mask[i] = 1.0
        return mask
|